r"""Script for training, evaluation and sampling for Real NVP. |
|
|
|
$ python real_nvp_multiscale_dataset.py \ |
|
--alsologtostderr \ |
|
--image_size 64 \ |
|
--hpconfig=n_scale=5,base_dim=8 \ |
|
--dataset imnet \ |
|
--data_path [DATA_PATH] |
|
""" |
|
|
|
from __future__ import print_function |
|
|
|
import time |
|
from datetime import datetime |
|
import os |
|
|
|
import numpy |
|
from six.moves import xrange |
|
import tensorflow as tf |
|
|
|
from tensorflow import gfile |
|
|
|
from real_nvp_utils import ( |
|
batch_norm, batch_norm_log_diff, conv_layer, |
|
squeeze_2x2, squeeze_2x2_ordered, standard_normal_ll, |
|
standard_normal_sample, unsqueeze_2x2, variable_on_cpu) |
|
|
|
|
|
tf.flags.DEFINE_string("master", "local", |
|
"BNS name of the TensorFlow master, or local.") |
|
|
|
tf.flags.DEFINE_string("logdir", "/tmp/real_nvp_multiscale", |
|
"Directory to which writes logs.") |
|
|
|
tf.flags.DEFINE_string("traindir", "/tmp/real_nvp_multiscale", |
|
"Directory to which writes logs.") |
|
|
|
tf.flags.DEFINE_integer("train_steps", 1000000000000000000, |
|
"Number of steps to train for.") |
|
|
|
tf.flags.DEFINE_string("data_path", "", "Path to the data.") |
|
|
|
tf.flags.DEFINE_string("mode", "train", |
|
"Mode of execution. Must be 'train', " |
|
"'sample' or 'eval'.") |
|
|
|
tf.flags.DEFINE_string("dataset", "imnet", |
|
"Dataset used. Must be 'imnet', " |
|
"'celeba' or 'lsun'.") |
|
|
|
tf.flags.DEFINE_integer("recursion_type", 2, |
|
"Type of the recursion.") |
|
|
|
tf.flags.DEFINE_integer("image_size", 64, |
|
"Size of the input image.") |
|
|
|
tf.flags.DEFINE_integer("eval_set_size", 0, |
|
"Size of evaluation dataset.") |
|
|
|
tf.flags.DEFINE_string(
    "hpconfig", "",
    "A comma separated list of hyperparameters for the model. Format is "
    "hp1=value1,hp2=value2,etc. If this FLAG is set, the model will be trained "
    "with the specified hyperparameters, filling in missing hyperparameters "
    "with their defaults from get_default_hparams().")
|
|
|
FLAGS = tf.flags.FLAGS |
|
|
|
class HParams(object): |
|
"""Dictionary of hyperparameters.""" |
|
def __init__(self, **kwargs): |
|
self.dict_ = kwargs |
|
self.__dict__.update(self.dict_) |
|
|
|
  def update_config(self, in_string):
    """Update the dictionary with a comma separated list."""
    # An empty string (e.g. the default --hpconfig) leaves defaults untouched.
    if not in_string:
      return self
    pairs = in_string.split(",")
    pairs = [pair.split("=") for pair in pairs]
    for key, val in pairs:
      self.dict_[key] = type(self.dict_[key])(val)
    self.__dict__.update(self.dict_)
    return self
|
|
|
def __getitem__(self, key): |
|
return self.dict_[key] |
|
|
|
def __setitem__(self, key, val): |
|
self.dict_[key] = val |
|
self.__dict__.update(self.dict_) |
|
|
|
|
|
def get_default_hparams(): |
|
"""Get the default hyperparameters.""" |
|
return HParams( |
|
batch_size=64, |
|
residual_blocks=2, |
|
n_couplings=2, |
|
n_scale=4, |
|
learning_rate=0.001, |
|
momentum=1e-1, |
|
decay=1e-3, |
|
l2_coeff=0.00005, |
|
clip_gradient=100., |
|
optimizer="adam", |
|
dropout_mask=0, |
|
base_dim=32, |
|
bottleneck=0, |
|
use_batch_norm=1, |
|
alternate=1, |
|
use_aff=1, |
|
skip=1, |
|
data_constraint=.9, |
|
n_opt=0) |
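
# Example: passing --hpconfig=n_scale=5,base_dim=8 is equivalent to
#   hps = get_default_hparams().update_config("n_scale=5,base_dim=8")
# which overrides n_scale and base_dim, keeps every other default, and casts
# each parsed value to the type of the corresponding default.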
|
|
|
|
|
|
|
def residual_block(input_, dim, name, use_batch_norm=True, |
|
train=True, weight_norm=True, bottleneck=False): |
|
"""Residual convolutional block.""" |
|
with tf.variable_scope(name): |
|
res = input_ |
|
if use_batch_norm: |
|
res = batch_norm( |
|
input_=res, dim=dim, name="bn_in", scale=False, |
|
train=train, epsilon=1e-4, axes=[0, 1, 2]) |
|
res = tf.nn.relu(res) |
|
if bottleneck: |
|
res = conv_layer( |
|
input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim, |
|
name="h_0", stddev=numpy.sqrt(2. / (dim)), |
|
strides=[1, 1, 1, 1], padding="SAME", |
|
nonlinearity=None, bias=(not use_batch_norm), |
|
weight_norm=weight_norm, scale=False) |
|
if use_batch_norm: |
|
res = batch_norm( |
|
input_=res, dim=dim, |
|
name="bn_0", scale=False, train=train, |
|
epsilon=1e-4, axes=[0, 1, 2]) |
|
res = tf.nn.relu(res) |
|
res = conv_layer( |
|
input_=res, filter_size=[3, 3], dim_in=dim, |
|
dim_out=dim, name="h_1", stddev=numpy.sqrt(2. / (1. * dim)), |
|
strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None, |
|
bias=(not use_batch_norm), |
|
weight_norm=weight_norm, scale=False) |
|
if use_batch_norm: |
|
res = batch_norm( |
|
input_=res, dim=dim, name="bn_1", scale=False, |
|
train=train, epsilon=1e-4, axes=[0, 1, 2]) |
|
res = tf.nn.relu(res) |
|
res = conv_layer( |
|
input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim, |
|
name="out", stddev=numpy.sqrt(2. / (1. * dim)), |
|
strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None, |
|
bias=True, weight_norm=weight_norm, scale=True) |
|
else: |
|
res = conv_layer( |
|
input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim, |
|
name="h_0", stddev=numpy.sqrt(2. / (dim)), |
|
strides=[1, 1, 1, 1], padding="SAME", |
|
nonlinearity=None, bias=(not use_batch_norm), |
|
weight_norm=weight_norm, scale=False) |
|
if use_batch_norm: |
|
res = batch_norm( |
|
input_=res, dim=dim, name="bn_0", scale=False, |
|
train=train, epsilon=1e-4, axes=[0, 1, 2]) |
|
res = tf.nn.relu(res) |
|
res = conv_layer( |
|
input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim, |
|
name="out", stddev=numpy.sqrt(2. / (1. * dim)), |
|
strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None, |
|
bias=True, weight_norm=weight_norm, scale=True) |
|
res += input_ |
|
|
|
return res |
|
|
|
|
|
def resnet(input_, dim_in, dim, dim_out, name, use_batch_norm=True, |
|
train=True, weight_norm=True, residual_blocks=5, |
|
bottleneck=False, skip=True): |
|
"""Residual convolutional network.""" |
|
with tf.variable_scope(name): |
|
res = input_ |
|
if residual_blocks != 0: |
|
res = conv_layer( |
|
input_=res, filter_size=[3, 3], dim_in=dim_in, dim_out=dim, |
|
name="h_in", stddev=numpy.sqrt(2. / (dim_in)), |
|
strides=[1, 1, 1, 1], padding="SAME", |
|
nonlinearity=None, bias=True, |
|
weight_norm=weight_norm, scale=False) |
|
if skip: |
|
out = conv_layer( |
|
input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim, |
|
name="skip_in", stddev=numpy.sqrt(2. / (dim)), |
|
strides=[1, 1, 1, 1], padding="SAME", |
|
nonlinearity=None, bias=True, |
|
weight_norm=weight_norm, scale=True) |
|
|
|
|
|
for idx_block in xrange(residual_blocks): |
|
res = residual_block(res, dim, "block_%d" % idx_block, |
|
use_batch_norm=use_batch_norm, train=train, |
|
weight_norm=weight_norm, |
|
bottleneck=bottleneck) |
|
if skip: |
|
out += conv_layer( |
|
input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim, |
|
name="skip_%d" % idx_block, stddev=numpy.sqrt(2. / (dim)), |
|
strides=[1, 1, 1, 1], padding="SAME", |
|
nonlinearity=None, bias=True, |
|
weight_norm=weight_norm, scale=True) |
|
|
|
if skip: |
|
res = out |
|
if use_batch_norm: |
|
res = batch_norm( |
|
input_=res, dim=dim, name="bn_pre_out", scale=False, |
|
train=train, epsilon=1e-4, axes=[0, 1, 2]) |
|
res = tf.nn.relu(res) |
|
res = conv_layer( |
|
input_=res, filter_size=[1, 1], dim_in=dim, |
|
dim_out=dim_out, |
|
name="out", stddev=numpy.sqrt(2. / (1. * dim)), |
|
strides=[1, 1, 1, 1], padding="SAME", |
|
nonlinearity=None, bias=True, |
|
weight_norm=weight_norm, scale=True) |
|
else: |
|
if bottleneck: |
|
res = conv_layer( |
|
input_=res, filter_size=[1, 1], dim_in=dim_in, dim_out=dim, |
|
name="h_0", stddev=numpy.sqrt(2. / (dim_in)), |
|
strides=[1, 1, 1, 1], padding="SAME", |
|
nonlinearity=None, bias=(not use_batch_norm), |
|
weight_norm=weight_norm, scale=False) |
|
if use_batch_norm: |
|
res = batch_norm( |
|
input_=res, dim=dim, name="bn_0", scale=False, |
|
train=train, epsilon=1e-4, axes=[0, 1, 2]) |
|
res = tf.nn.relu(res) |
|
res = conv_layer( |
|
input_=res, filter_size=[3, 3], dim_in=dim, |
|
dim_out=dim, name="h_1", stddev=numpy.sqrt(2. / (1. * dim)), |
|
strides=[1, 1, 1, 1], padding="SAME", |
|
nonlinearity=None, |
|
bias=(not use_batch_norm), |
|
weight_norm=weight_norm, scale=False) |
|
if use_batch_norm: |
|
res = batch_norm( |
|
input_=res, dim=dim, name="bn_1", scale=False, |
|
train=train, epsilon=1e-4, axes=[0, 1, 2]) |
|
res = tf.nn.relu(res) |
|
res = conv_layer( |
|
input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim_out, |
|
name="out", stddev=numpy.sqrt(2. / (1. * dim)), |
|
strides=[1, 1, 1, 1], padding="SAME", |
|
nonlinearity=None, bias=True, |
|
weight_norm=weight_norm, scale=True) |
|
else: |
|
res = conv_layer( |
|
input_=res, filter_size=[3, 3], dim_in=dim_in, dim_out=dim, |
|
name="h_0", stddev=numpy.sqrt(2. / (dim_in)), |
|
strides=[1, 1, 1, 1], padding="SAME", |
|
nonlinearity=None, bias=(not use_batch_norm), |
|
weight_norm=weight_norm, scale=False) |
|
if use_batch_norm: |
|
res = batch_norm( |
|
input_=res, dim=dim, name="bn_0", scale=False, |
|
train=train, epsilon=1e-4, axes=[0, 1, 2]) |
|
res = tf.nn.relu(res) |
|
res = conv_layer( |
|
input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim_out, |
|
name="out", stddev=numpy.sqrt(2. / (1. * dim)), |
|
strides=[1, 1, 1, 1], padding="SAME", |
|
nonlinearity=None, bias=True, |
|
weight_norm=weight_norm, scale=True) |
|
return res |
|
|
|
|
|
|
|
|
|
def masked_conv_aff_coupling(input_, mask_in, dim, name, |
|
use_batch_norm=True, train=True, weight_norm=True, |
|
reverse=False, residual_blocks=5, |
|
bottleneck=False, use_width=1., use_height=1., |
|
mask_channel=0., skip=True): |
|
"""Affine coupling with masked convolution.""" |
|
with tf.variable_scope(name) as scope: |
|
if reverse or (not train): |
|
scope.reuse_variables() |
|
shape = input_.get_shape().as_list() |
|
batch_size = shape[0] |
|
height = shape[1] |
|
width = shape[2] |
|
channels = shape[3] |
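    # Build a checkerboard mask from the parity of (use_height * row +
    # use_width * col), offset by mask_in: entries equal to 1 are left
    # unchanged by the coupling, entries equal to 0 are transformed.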
|
|
|
|
|
mask = use_width * numpy.arange(width) |
|
mask = use_height * numpy.arange(height).reshape((-1, 1)) + mask |
|
mask = mask.astype("float32") |
|
mask = tf.mod(mask_in + mask, 2) |
|
mask = tf.reshape(mask, [-1, height, width, 1]) |
|
if mask.get_shape().as_list()[0] == 1: |
|
mask = tf.tile(mask, [batch_size, 1, 1, 1]) |
|
res = input_ * tf.mod(mask_channel + mask, 2) |
|
|
|
|
|
if use_batch_norm: |
|
res = batch_norm( |
|
input_=res, dim=channels, name="bn_in", scale=False, |
|
train=train, epsilon=1e-4, axes=[0, 1, 2]) |
|
res *= 2. |
|
res = tf.concat([res, -res], 3) |
|
res = tf.concat([res, mask], 3) |
|
    dim_in = 2 * channels + 1
|
res = tf.nn.relu(res) |
|
res = resnet(input_=res, dim_in=dim_in, dim=dim, |
|
dim_out=2 * channels, |
|
name="resnet", use_batch_norm=use_batch_norm, |
|
train=train, weight_norm=weight_norm, |
|
residual_blocks=residual_blocks, |
|
bottleneck=bottleneck, skip=skip) |
|
mask = tf.mod(mask_channel + mask, 2) |
|
    shift, log_rescaling = tf.split(axis=3, num_or_size_splits=2, value=res)
|
scale = variable_on_cpu( |
|
"rescaling_scale", [], |
|
tf.constant_initializer(0.)) |
|
shift = tf.reshape( |
|
shift, [batch_size, height, width, channels]) |
|
log_rescaling = tf.reshape( |
|
log_rescaling, [batch_size, height, width, channels]) |
|
log_rescaling = scale * tf.tanh(log_rescaling) |
|
if not use_batch_norm: |
|
scale_shift = variable_on_cpu( |
|
"scale_shift", [], |
|
tf.constant_initializer(0.)) |
|
log_rescaling += scale_shift |
|
shift *= (1. - mask) |
|
log_rescaling *= (1. - mask) |
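    # Forward pass maps the unmasked entries to (x + shift) * exp(log_rescaling)
    # and the reverse pass inverts that map; log_diff accumulates the diagonal
    # log-Jacobian of the coupling (plus batch norm corrections).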
|
if reverse: |
|
res = input_ |
|
if use_batch_norm: |
|
mean, var = batch_norm_log_diff( |
|
input_=res * (1. - mask), dim=channels, name="bn_out", |
|
train=False, epsilon=1e-4, axes=[0, 1, 2]) |
|
log_var = tf.log(var) |
|
res *= tf.exp(.5 * log_var * (1. - mask)) |
|
res += mean * (1. - mask) |
|
res *= tf.exp(-log_rescaling) |
|
res -= shift |
|
log_diff = -log_rescaling |
|
if use_batch_norm: |
|
log_diff += .5 * log_var * (1. - mask) |
|
else: |
|
res = input_ |
|
res += shift |
|
res *= tf.exp(log_rescaling) |
|
log_diff = log_rescaling |
|
if use_batch_norm: |
|
mean, var = batch_norm_log_diff( |
|
input_=res * (1. - mask), dim=channels, name="bn_out", |
|
train=train, epsilon=1e-4, axes=[0, 1, 2]) |
|
log_var = tf.log(var) |
|
res -= mean * (1. - mask) |
|
res *= tf.exp(-.5 * log_var * (1. - mask)) |
|
log_diff -= .5 * log_var * (1. - mask) |
|
|
|
return res, log_diff |
|
|
|
|
|
def masked_conv_add_coupling(input_, mask_in, dim, name, |
|
use_batch_norm=True, train=True, weight_norm=True, |
|
reverse=False, residual_blocks=5, |
|
bottleneck=False, use_width=1., use_height=1., |
|
mask_channel=0., skip=True): |
|
"""Additive coupling with masked convolution.""" |
|
with tf.variable_scope(name) as scope: |
|
if reverse or (not train): |
|
scope.reuse_variables() |
|
shape = input_.get_shape().as_list() |
|
batch_size = shape[0] |
|
height = shape[1] |
|
width = shape[2] |
|
channels = shape[3] |
|
|
|
|
|
mask = use_width * numpy.arange(width) |
|
mask = use_height * numpy.arange(height).reshape((-1, 1)) + mask |
|
mask = mask.astype("float32") |
|
mask = tf.mod(mask_in + mask, 2) |
|
mask = tf.reshape(mask, [-1, height, width, 1]) |
|
if mask.get_shape().as_list()[0] == 1: |
|
mask = tf.tile(mask, [batch_size, 1, 1, 1]) |
|
res = input_ * tf.mod(mask_channel + mask, 2) |
|
|
|
|
|
if use_batch_norm: |
|
res = batch_norm( |
|
input_=res, dim=channels, name="bn_in", scale=False, |
|
train=train, epsilon=1e-4, axes=[0, 1, 2]) |
|
res *= 2. |
|
res = tf.concat([res, -res], 3) |
|
res = tf.concat([res, mask], 3) |
|
    dim_in = 2 * channels + 1
|
res = tf.nn.relu(res) |
|
shift = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=channels, |
|
name="resnet", use_batch_norm=use_batch_norm, |
|
train=train, weight_norm=weight_norm, |
|
residual_blocks=residual_blocks, |
|
bottleneck=bottleneck, skip=skip) |
|
mask = tf.mod(mask_channel + mask, 2) |
|
shift *= (1. - mask) |
|
|
|
if reverse: |
|
res = input_ |
|
if use_batch_norm: |
|
mean, var = batch_norm_log_diff( |
|
input_=res * (1. - mask), |
|
dim=channels, name="bn_out", train=False, epsilon=1e-4) |
|
log_var = tf.log(var) |
|
res *= tf.exp(.5 * log_var * (1. - mask)) |
|
res += mean * (1. - mask) |
|
res -= shift |
|
log_diff = tf.zeros_like(res) |
|
if use_batch_norm: |
|
log_diff += .5 * log_var * (1. - mask) |
|
else: |
|
res = input_ |
|
res += shift |
|
log_diff = tf.zeros_like(res) |
|
if use_batch_norm: |
|
mean, var = batch_norm_log_diff( |
|
input_=res * (1. - mask), dim=channels, |
|
name="bn_out", train=train, epsilon=1e-4, axes=[0, 1, 2]) |
|
log_var = tf.log(var) |
|
res -= mean * (1. - mask) |
|
res *= tf.exp(-.5 * log_var * (1. - mask)) |
|
log_diff -= .5 * log_var * (1. - mask) |
|
|
|
return res, log_diff |
|
|
|
|
|
def masked_conv_coupling(input_, mask_in, dim, name, |
|
use_batch_norm=True, train=True, weight_norm=True, |
|
reverse=False, residual_blocks=5, |
|
bottleneck=False, use_aff=True, |
|
use_width=1., use_height=1., |
|
mask_channel=0., skip=True): |
|
"""Coupling with masked convolution.""" |
|
if use_aff: |
|
return masked_conv_aff_coupling( |
|
input_=input_, mask_in=mask_in, dim=dim, name=name, |
|
use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm, |
|
reverse=reverse, residual_blocks=residual_blocks, |
|
bottleneck=bottleneck, use_width=use_width, use_height=use_height, |
|
mask_channel=mask_channel, skip=skip) |
|
else: |
|
return masked_conv_add_coupling( |
|
input_=input_, mask_in=mask_in, dim=dim, name=name, |
|
use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm, |
|
reverse=reverse, residual_blocks=residual_blocks, |
|
bottleneck=bottleneck, use_width=use_width, use_height=use_height, |
|
mask_channel=mask_channel, skip=skip) |
|
|
|
|
|
|
|
def conv_ch_aff_coupling(input_, dim, name, |
|
use_batch_norm=True, train=True, weight_norm=True, |
|
reverse=False, residual_blocks=5, |
|
bottleneck=False, change_bottom=True, skip=True): |
|
"""Affine coupling with channel-wise splitting.""" |
|
with tf.variable_scope(name) as scope: |
|
if reverse or (not train): |
|
scope.reuse_variables() |
|
|
|
if change_bottom: |
|
input_, canvas = tf.split(axis=3, num_or_size_splits=2, value=input_) |
|
else: |
|
canvas, input_ = tf.split(axis=3, num_or_size_splits=2, value=input_) |
|
shape = input_.get_shape().as_list() |
|
batch_size = shape[0] |
|
height = shape[1] |
|
width = shape[2] |
|
channels = shape[3] |
|
res = input_ |
|
|
|
|
|
if use_batch_norm: |
|
res = batch_norm( |
|
input_=res, dim=channels, name="bn_in", scale=False, |
|
train=train, epsilon=1e-4, axes=[0, 1, 2]) |
|
res = tf.concat([res, -res], 3) |
|
    dim_in = 2 * channels
|
res = tf.nn.relu(res) |
|
res = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=2 * channels, |
|
name="resnet", use_batch_norm=use_batch_norm, |
|
train=train, weight_norm=weight_norm, |
|
residual_blocks=residual_blocks, |
|
bottleneck=bottleneck, skip=skip) |
|
shift, log_rescaling = tf.split(axis=3, num_or_size_splits=2, value=res) |
|
scale = variable_on_cpu( |
|
"scale", [], |
|
tf.constant_initializer(1.)) |
|
shift = tf.reshape( |
|
shift, [batch_size, height, width, channels]) |
|
log_rescaling = tf.reshape( |
|
log_rescaling, [batch_size, height, width, channels]) |
|
log_rescaling = scale * tf.tanh(log_rescaling) |
|
if not use_batch_norm: |
|
scale_shift = variable_on_cpu( |
|
"scale_shift", [], |
|
tf.constant_initializer(0.)) |
|
log_rescaling += scale_shift |
|
if reverse: |
|
res = canvas |
|
if use_batch_norm: |
|
mean, var = batch_norm_log_diff( |
|
input_=res, dim=channels, name="bn_out", train=False, |
|
epsilon=1e-4, axes=[0, 1, 2]) |
|
log_var = tf.log(var) |
|
res *= tf.exp(.5 * log_var) |
|
res += mean |
|
res *= tf.exp(-log_rescaling) |
|
res -= shift |
|
log_diff = -log_rescaling |
|
if use_batch_norm: |
|
log_diff += .5 * log_var |
|
else: |
|
res = canvas |
|
res += shift |
|
res *= tf.exp(log_rescaling) |
|
log_diff = log_rescaling |
|
if use_batch_norm: |
|
mean, var = batch_norm_log_diff( |
|
input_=res, dim=channels, name="bn_out", train=train, |
|
epsilon=1e-4, axes=[0, 1, 2]) |
|
log_var = tf.log(var) |
|
res -= mean |
|
res *= tf.exp(-.5 * log_var) |
|
log_diff -= .5 * log_var |
|
if change_bottom: |
|
res = tf.concat([input_, res], 3) |
|
log_diff = tf.concat([tf.zeros_like(log_diff), log_diff], 3) |
|
else: |
|
res = tf.concat([res, input_], 3) |
|
log_diff = tf.concat([log_diff, tf.zeros_like(log_diff)], 3) |
|
|
|
return res, log_diff |
|
|
|
|
|
def conv_ch_add_coupling(input_, dim, name, |
|
use_batch_norm=True, train=True, weight_norm=True, |
|
reverse=False, residual_blocks=5, |
|
bottleneck=False, change_bottom=True, skip=True): |
|
"""Additive coupling with channel-wise splitting.""" |
|
with tf.variable_scope(name) as scope: |
|
if reverse or (not train): |
|
scope.reuse_variables() |
|
|
|
if change_bottom: |
|
input_, canvas = tf.split(axis=3, num_or_size_splits=2, value=input_) |
|
else: |
|
canvas, input_ = tf.split(axis=3, num_or_size_splits=2, value=input_) |
|
shape = input_.get_shape().as_list() |
|
channels = shape[3] |
|
res = input_ |
|
|
|
|
|
if use_batch_norm: |
|
res = batch_norm( |
|
input_=res, dim=channels, name="bn_in", scale=False, |
|
train=train, epsilon=1e-4, axes=[0, 1, 2]) |
|
res = tf.concat([res, -res], 3) |
|
    dim_in = 2 * channels
|
res = tf.nn.relu(res) |
|
shift = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=channels, |
|
name="resnet", use_batch_norm=use_batch_norm, |
|
train=train, weight_norm=weight_norm, |
|
residual_blocks=residual_blocks, |
|
bottleneck=bottleneck, skip=skip) |
|
if reverse: |
|
res = canvas |
|
if use_batch_norm: |
|
mean, var = batch_norm_log_diff( |
|
input_=res, dim=channels, name="bn_out", train=False, |
|
epsilon=1e-4, axes=[0, 1, 2]) |
|
log_var = tf.log(var) |
|
res *= tf.exp(.5 * log_var) |
|
res += mean |
|
res -= shift |
|
log_diff = tf.zeros_like(res) |
|
if use_batch_norm: |
|
log_diff += .5 * log_var |
|
else: |
|
res = canvas |
|
res += shift |
|
log_diff = tf.zeros_like(res) |
|
if use_batch_norm: |
|
mean, var = batch_norm_log_diff( |
|
input_=res, dim=channels, name="bn_out", train=train, |
|
epsilon=1e-4, axes=[0, 1, 2]) |
|
log_var = tf.log(var) |
|
res -= mean |
|
res *= tf.exp(-.5 * log_var) |
|
log_diff -= .5 * log_var |
|
if change_bottom: |
|
res = tf.concat([input_, res], 3) |
|
log_diff = tf.concat([tf.zeros_like(log_diff), log_diff], 3) |
|
else: |
|
res = tf.concat([res, input_], 3) |
|
log_diff = tf.concat([log_diff, tf.zeros_like(log_diff)], 3) |
|
|
|
return res, log_diff |
|
|
|
|
|
def conv_ch_coupling(input_, dim, name, |
|
use_batch_norm=True, train=True, weight_norm=True, |
|
reverse=False, residual_blocks=5, |
|
bottleneck=False, use_aff=True, change_bottom=True, |
|
skip=True): |
|
"""Coupling with channel-wise splitting.""" |
|
if use_aff: |
|
return conv_ch_aff_coupling( |
|
input_=input_, dim=dim, name=name, |
|
use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm, |
|
reverse=reverse, residual_blocks=residual_blocks, |
|
bottleneck=bottleneck, change_bottom=change_bottom, skip=skip) |
|
else: |
|
return conv_ch_add_coupling( |
|
input_=input_, dim=dim, name=name, |
|
use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm, |
|
reverse=reverse, residual_blocks=residual_blocks, |
|
bottleneck=bottleneck, change_bottom=change_bottom, skip=skip) |
|
|
|
|
|
|
|
def rec_masked_conv_coupling(input_, hps, scale_idx, n_scale, |
|
use_batch_norm=True, weight_norm=True, |
|
train=True): |
|
"""Recursion on coupling layers.""" |
|
shape = input_.get_shape().as_list() |
|
channels = shape[3] |
|
residual_blocks = hps.residual_blocks |
|
base_dim = hps.base_dim |
|
mask = 1. |
|
use_aff = hps.use_aff |
|
res = input_ |
|
skip = hps.skip |
|
log_diff = tf.zeros_like(input_) |
|
dim = base_dim |
|
if FLAGS.recursion_type < 4: |
|
dim *= 2 ** scale_idx |
|
with tf.variable_scope("scale_%d" % scale_idx): |
|
|
|
res, inc_log_diff = masked_conv_coupling( |
|
input_=res, |
|
mask_in=mask, dim=dim, |
|
name="coupling_0", |
|
use_batch_norm=use_batch_norm, train=train, |
|
weight_norm=weight_norm, |
|
reverse=False, residual_blocks=residual_blocks, |
|
bottleneck=hps.bottleneck, use_aff=use_aff, |
|
use_width=1., use_height=1., skip=skip) |
|
log_diff += inc_log_diff |
|
res, inc_log_diff = masked_conv_coupling( |
|
input_=res, |
|
mask_in=1. - mask, dim=dim, |
|
name="coupling_1", |
|
use_batch_norm=use_batch_norm, train=train, |
|
weight_norm=weight_norm, |
|
reverse=False, residual_blocks=residual_blocks, |
|
bottleneck=hps.bottleneck, use_aff=use_aff, |
|
use_width=1., use_height=1., skip=skip) |
|
log_diff += inc_log_diff |
|
res, inc_log_diff = masked_conv_coupling( |
|
input_=res, |
|
mask_in=mask, dim=dim, |
|
name="coupling_2", |
|
use_batch_norm=use_batch_norm, train=train, |
|
weight_norm=weight_norm, |
|
reverse=False, residual_blocks=residual_blocks, |
|
bottleneck=hps.bottleneck, use_aff=True, |
|
use_width=1., use_height=1., skip=skip) |
|
log_diff += inc_log_diff |
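  # At every scale except the last: squeeze to trade spatial extent for
  # channels, apply three channel-wise couplings, then factor out half of
  # the dimensions and recurse on the remaining half.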
|
if scale_idx < (n_scale - 1): |
|
with tf.variable_scope("scale_%d" % scale_idx): |
|
res = squeeze_2x2(res) |
|
log_diff = squeeze_2x2(log_diff) |
|
res, inc_log_diff = conv_ch_coupling( |
|
input_=res, |
|
change_bottom=True, dim=2 * dim, |
|
name="coupling_4", |
|
use_batch_norm=use_batch_norm, train=train, |
|
weight_norm=weight_norm, |
|
reverse=False, residual_blocks=residual_blocks, |
|
bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip) |
|
log_diff += inc_log_diff |
|
res, inc_log_diff = conv_ch_coupling( |
|
input_=res, |
|
change_bottom=False, dim=2 * dim, |
|
name="coupling_5", |
|
use_batch_norm=use_batch_norm, train=train, |
|
weight_norm=weight_norm, |
|
reverse=False, residual_blocks=residual_blocks, |
|
bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip) |
|
log_diff += inc_log_diff |
|
res, inc_log_diff = conv_ch_coupling( |
|
input_=res, |
|
change_bottom=True, dim=2 * dim, |
|
name="coupling_6", |
|
use_batch_norm=use_batch_norm, train=train, |
|
weight_norm=weight_norm, |
|
reverse=False, residual_blocks=residual_blocks, |
|
bottleneck=hps.bottleneck, use_aff=True, skip=skip) |
|
log_diff += inc_log_diff |
|
res = unsqueeze_2x2(res) |
|
log_diff = unsqueeze_2x2(log_diff) |
|
if FLAGS.recursion_type > 1: |
|
res = squeeze_2x2_ordered(res) |
|
log_diff = squeeze_2x2_ordered(log_diff) |
|
if FLAGS.recursion_type > 2: |
|
res_1 = res[:, :, :, :channels] |
|
res_2 = res[:, :, :, channels:] |
|
log_diff_1 = log_diff[:, :, :, :channels] |
|
log_diff_2 = log_diff[:, :, :, channels:] |
|
else: |
|
res_1, res_2 = tf.split(axis=3, num_or_size_splits=2, value=res) |
|
log_diff_1, log_diff_2 = tf.split(axis=3, num_or_size_splits=2, value=log_diff) |
|
res_1, inc_log_diff = rec_masked_conv_coupling( |
|
input_=res_1, hps=hps, scale_idx=scale_idx + 1, n_scale=n_scale, |
|
use_batch_norm=use_batch_norm, weight_norm=weight_norm, |
|
train=train) |
|
res = tf.concat([res_1, res_2], 3) |
|
log_diff_1 += inc_log_diff |
|
log_diff = tf.concat([log_diff_1, log_diff_2], 3) |
|
res = squeeze_2x2_ordered(res, reverse=True) |
|
log_diff = squeeze_2x2_ordered(log_diff, reverse=True) |
|
else: |
|
res = squeeze_2x2_ordered(res) |
|
log_diff = squeeze_2x2_ordered(log_diff) |
|
res, inc_log_diff = rec_masked_conv_coupling( |
|
input_=res, hps=hps, scale_idx=scale_idx + 1, n_scale=n_scale, |
|
use_batch_norm=use_batch_norm, weight_norm=weight_norm, |
|
train=train) |
|
log_diff += inc_log_diff |
|
res = squeeze_2x2_ordered(res, reverse=True) |
|
log_diff = squeeze_2x2_ordered(log_diff, reverse=True) |
|
else: |
|
with tf.variable_scope("scale_%d" % scale_idx): |
|
res, inc_log_diff = masked_conv_coupling( |
|
input_=res, |
|
mask_in=1. - mask, dim=dim, |
|
name="coupling_3", |
|
use_batch_norm=use_batch_norm, train=train, |
|
weight_norm=weight_norm, |
|
reverse=False, residual_blocks=residual_blocks, |
|
bottleneck=hps.bottleneck, use_aff=True, |
|
use_width=1., use_height=1., skip=skip) |
|
log_diff += inc_log_diff |
|
return res, log_diff |
|
|
|
|
|
def rec_masked_deconv_coupling(input_, hps, scale_idx, n_scale, |
|
use_batch_norm=True, weight_norm=True, |
|
train=True): |
|
"""Recursion on inverting coupling layers.""" |
|
shape = input_.get_shape().as_list() |
|
channels = shape[3] |
|
residual_blocks = hps.residual_blocks |
|
base_dim = hps.base_dim |
|
mask = 1. |
|
use_aff = hps.use_aff |
|
res = input_ |
|
log_diff = tf.zeros_like(input_) |
|
skip = hps.skip |
|
dim = base_dim |
|
if FLAGS.recursion_type < 4: |
|
dim *= 2 ** scale_idx |
|
if scale_idx < (n_scale - 1): |
|
if FLAGS.recursion_type > 1: |
|
res = squeeze_2x2_ordered(res) |
|
log_diff = squeeze_2x2_ordered(log_diff) |
|
if FLAGS.recursion_type > 2: |
|
res_1 = res[:, :, :, :channels] |
|
res_2 = res[:, :, :, channels:] |
|
log_diff_1 = log_diff[:, :, :, :channels] |
|
log_diff_2 = log_diff[:, :, :, channels:] |
|
else: |
|
res_1, res_2 = tf.split(axis=3, num_or_size_splits=2, value=res) |
|
log_diff_1, log_diff_2 = tf.split(axis=3, num_or_size_splits=2, value=log_diff) |
|
res_1, log_diff_1 = rec_masked_deconv_coupling( |
|
input_=res_1, hps=hps, |
|
scale_idx=scale_idx + 1, n_scale=n_scale, |
|
use_batch_norm=use_batch_norm, weight_norm=weight_norm, |
|
train=train) |
|
res = tf.concat([res_1, res_2], 3) |
|
log_diff = tf.concat([log_diff_1, log_diff_2], 3) |
|
res = squeeze_2x2_ordered(res, reverse=True) |
|
log_diff = squeeze_2x2_ordered(log_diff, reverse=True) |
|
else: |
|
res = squeeze_2x2_ordered(res) |
|
log_diff = squeeze_2x2_ordered(log_diff) |
|
res, log_diff = rec_masked_deconv_coupling( |
|
input_=res, hps=hps, |
|
scale_idx=scale_idx + 1, n_scale=n_scale, |
|
use_batch_norm=use_batch_norm, weight_norm=weight_norm, |
|
train=train) |
|
res = squeeze_2x2_ordered(res, reverse=True) |
|
log_diff = squeeze_2x2_ordered(log_diff, reverse=True) |
|
with tf.variable_scope("scale_%d" % scale_idx): |
|
res = squeeze_2x2(res) |
|
log_diff = squeeze_2x2(log_diff) |
|
res, inc_log_diff = conv_ch_coupling( |
|
input_=res, |
|
change_bottom=True, dim=2 * dim, |
|
name="coupling_6", |
|
use_batch_norm=use_batch_norm, train=train, |
|
weight_norm=weight_norm, |
|
reverse=True, residual_blocks=residual_blocks, |
|
bottleneck=hps.bottleneck, use_aff=True, skip=skip) |
|
log_diff += inc_log_diff |
|
res, inc_log_diff = conv_ch_coupling( |
|
input_=res, |
|
change_bottom=False, dim=2 * dim, |
|
name="coupling_5", |
|
use_batch_norm=use_batch_norm, train=train, |
|
weight_norm=weight_norm, |
|
reverse=True, residual_blocks=residual_blocks, |
|
bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip) |
|
log_diff += inc_log_diff |
|
res, inc_log_diff = conv_ch_coupling( |
|
input_=res, |
|
change_bottom=True, dim=2 * dim, |
|
name="coupling_4", |
|
use_batch_norm=use_batch_norm, train=train, |
|
weight_norm=weight_norm, |
|
reverse=True, residual_blocks=residual_blocks, |
|
bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip) |
|
log_diff += inc_log_diff |
|
res = unsqueeze_2x2(res) |
|
log_diff = unsqueeze_2x2(log_diff) |
|
else: |
|
with tf.variable_scope("scale_%d" % scale_idx): |
|
res, inc_log_diff = masked_conv_coupling( |
|
input_=res, |
|
mask_in=1. - mask, dim=dim, |
|
name="coupling_3", |
|
use_batch_norm=use_batch_norm, train=train, |
|
weight_norm=weight_norm, |
|
reverse=True, residual_blocks=residual_blocks, |
|
bottleneck=hps.bottleneck, use_aff=True, |
|
use_width=1., use_height=1., skip=skip) |
|
log_diff += inc_log_diff |
|
|
|
with tf.variable_scope("scale_%d" % scale_idx): |
|
res, inc_log_diff = masked_conv_coupling( |
|
input_=res, |
|
mask_in=mask, dim=dim, |
|
name="coupling_2", |
|
use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm, |
|
reverse=True, residual_blocks=residual_blocks, |
|
bottleneck=hps.bottleneck, use_aff=True, |
|
use_width=1., use_height=1., skip=skip) |
|
log_diff += inc_log_diff |
|
res, inc_log_diff = masked_conv_coupling( |
|
input_=res, |
|
mask_in=1. - mask, dim=dim, |
|
name="coupling_1", |
|
use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm, |
|
reverse=True, residual_blocks=residual_blocks, |
|
bottleneck=hps.bottleneck, use_aff=use_aff, |
|
use_width=1., use_height=1., skip=skip) |
|
log_diff += inc_log_diff |
|
res, inc_log_diff = masked_conv_coupling( |
|
input_=res, |
|
mask_in=mask, dim=dim, |
|
name="coupling_0", |
|
use_batch_norm=use_batch_norm, train=train, weight_norm=weight_norm, |
|
reverse=True, residual_blocks=residual_blocks, |
|
bottleneck=hps.bottleneck, use_aff=use_aff, |
|
use_width=1., use_height=1., skip=skip) |
|
log_diff += inc_log_diff |
|
|
|
return res, log_diff |
|
|
|
|
|
|
|
|
|
def encoder(input_, hps, n_scale, use_batch_norm=True, |
|
weight_norm=True, train=True): |
|
"""Encoding/gaussianization function.""" |
|
res = input_ |
|
log_diff = tf.zeros_like(input_) |
|
res, inc_log_diff = rec_masked_conv_coupling( |
|
input_=res, hps=hps, scale_idx=0, n_scale=n_scale, |
|
use_batch_norm=use_batch_norm, weight_norm=weight_norm, |
|
train=train) |
|
log_diff += inc_log_diff |
|
|
|
return res, log_diff |
|
|
|
|
|
def decoder(input_, hps, n_scale, use_batch_norm=True, |
|
weight_norm=True, train=True): |
|
"""Decoding/generator function.""" |
|
res, log_diff = rec_masked_deconv_coupling( |
|
input_=input_, hps=hps, scale_idx=0, n_scale=n_scale, |
|
use_batch_norm=use_batch_norm, weight_norm=weight_norm, |
|
train=train) |
|
|
|
return res, log_diff |
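
# Note: encoder and decoder implement inverse maps. A minimal sanity-check
# sketch (assuming the variables were already created by a train=True call):
#
#   z, fwd_log_diff = encoder(input_=x, hps=hps, n_scale=hps.n_scale,
#                             use_batch_norm=hps.use_batch_norm,
#                             weight_norm=True, train=False)
#   x_rec, bwd_log_diff = decoder(input_=z, hps=hps, n_scale=hps.n_scale,
#                                 use_batch_norm=hps.use_batch_norm,
#                                 weight_norm=True, train=False)
#
# x_rec should match x up to numerical error, and the two log-Jacobians
# should sum to approximately zero.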
|
|
|
|
|
class RealNVP(object): |
|
"""Real NVP model.""" |
|
|
|
def __init__(self, hps, sampling=False): |
|
|
|
device = "/cpu:0" |
|
if FLAGS.dataset == "imnet": |
|
with tf.device( |
|
tf.train.replica_device_setter(0, worker_device=device)): |
|
filename_queue = tf.train.string_input_producer( |
|
gfile.Glob(FLAGS.data_path), num_epochs=None) |
|
reader = tf.TFRecordReader() |
|
_, serialized_example = reader.read(filename_queue) |
|
features = tf.parse_single_example( |
|
serialized_example, |
|
features={ |
|
"image_raw": tf.FixedLenFeature([], tf.string), |
|
}) |
|
image = tf.decode_raw(features["image_raw"], tf.uint8) |
|
image.set_shape([FLAGS.image_size * FLAGS.image_size * 3]) |
|
image = tf.cast(image, tf.float32) |
|
if FLAGS.mode == "train": |
|
images = tf.train.shuffle_batch( |
|
[image], batch_size=hps.batch_size, num_threads=1, |
|
capacity=1000 + 3 * hps.batch_size, |
|
|
|
min_after_dequeue=1000) |
|
else: |
|
images = tf.train.batch( |
|
[image], batch_size=hps.batch_size, num_threads=1, |
|
capacity=1000 + 3 * hps.batch_size) |
|
self.x_orig = x_orig = images |
|
image_size = FLAGS.image_size |
|
x_in = tf.reshape( |
|
x_orig, |
|
[hps.batch_size, FLAGS.image_size, FLAGS.image_size, 3]) |
|
x_in = tf.clip_by_value(x_in, 0, 255) |
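        # Dequantize: add uniform [0, 1) noise to the integer pixel values
        # and rescale to [0, 1), so the model fits a continuous density.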
|
x_in = (tf.cast(x_in, tf.float32) |
|
+ tf.random_uniform(tf.shape(x_in))) / 256. |
|
elif FLAGS.dataset == "celeba": |
|
with tf.device( |
|
tf.train.replica_device_setter(0, worker_device=device)): |
|
filename_queue = tf.train.string_input_producer( |
|
gfile.Glob(FLAGS.data_path), num_epochs=None) |
|
reader = tf.TFRecordReader() |
|
_, serialized_example = reader.read(filename_queue) |
|
features = tf.parse_single_example( |
|
serialized_example, |
|
features={ |
|
"image_raw": tf.FixedLenFeature([], tf.string), |
|
}) |
|
image = tf.decode_raw(features["image_raw"], tf.uint8) |
|
image.set_shape([218 * 178 * 3]) |
|
image = tf.cast(image, tf.float32) |
|
image = tf.reshape(image, [218, 178, 3]) |
|
image = image[40:188, 15:163, :] |
|
if FLAGS.mode == "train": |
|
image = tf.image.random_flip_left_right(image) |
|
images = tf.train.shuffle_batch( |
|
[image], batch_size=hps.batch_size, num_threads=1, |
|
capacity=1000 + 3 * hps.batch_size, |
|
min_after_dequeue=1000) |
|
else: |
|
images = tf.train.batch( |
|
[image], batch_size=hps.batch_size, num_threads=1, |
|
capacity=1000 + 3 * hps.batch_size) |
|
self.x_orig = x_orig = images |
|
image_size = 64 |
|
x_in = tf.reshape(x_orig, [hps.batch_size, 148, 148, 3]) |
|
x_in = tf.image.resize_images( |
|
x_in, [64, 64], method=0, align_corners=False) |
|
x_in = (tf.cast(x_in, tf.float32) |
|
+ tf.random_uniform(tf.shape(x_in))) / 256. |
|
elif FLAGS.dataset == "lsun": |
|
with tf.device( |
|
tf.train.replica_device_setter(0, worker_device=device)): |
|
filename_queue = tf.train.string_input_producer( |
|
gfile.Glob(FLAGS.data_path), num_epochs=None) |
|
reader = tf.TFRecordReader() |
|
_, serialized_example = reader.read(filename_queue) |
|
features = tf.parse_single_example( |
|
serialized_example, |
|
features={ |
|
"image_raw": tf.FixedLenFeature([], tf.string), |
|
"height": tf.FixedLenFeature([], tf.int64), |
|
"width": tf.FixedLenFeature([], tf.int64), |
|
"depth": tf.FixedLenFeature([], tf.int64) |
|
}) |
|
image = tf.decode_raw(features["image_raw"], tf.uint8) |
|
        height = tf.reshape(features["height"], [1])
        height = tf.cast(height, tf.int32)
        width = tf.reshape(features["width"], [1])
        width = tf.cast(width, tf.int32)
        depth = tf.reshape(features["depth"], [1])
        depth = tf.cast(depth, tf.int32)
|
image = tf.reshape(image, tf.concat([height, width, depth], 0)) |
|
image = tf.random_crop(image, [64, 64, 3]) |
|
if FLAGS.mode == "train": |
|
image = tf.image.random_flip_left_right(image) |
|
images = tf.train.shuffle_batch( |
|
[image], batch_size=hps.batch_size, num_threads=1, |
|
capacity=1000 + 3 * hps.batch_size, |
|
|
|
min_after_dequeue=1000) |
|
else: |
|
images = tf.train.batch( |
|
[image], batch_size=hps.batch_size, num_threads=1, |
|
capacity=1000 + 3 * hps.batch_size) |
|
self.x_orig = x_orig = images |
|
image_size = 64 |
|
x_in = tf.reshape(x_orig, [hps.batch_size, 64, 64, 3]) |
|
x_in = (tf.cast(x_in, tf.float32) |
|
+ tf.random_uniform(tf.shape(x_in))) / 256. |
|
else: |
|
raise ValueError("Unknown dataset.") |
|
x_in = tf.reshape(x_in, [hps.batch_size, image_size, image_size, 3]) |
|
side_shown = int(numpy.sqrt(hps.batch_size)) |
|
shown_x = tf.transpose( |
|
tf.reshape( |
|
x_in[:(side_shown * side_shown), :, :, :], |
|
[side_shown, image_size * side_shown, image_size, 3]), |
|
[0, 2, 1, 3]) |
|
shown_x = tf.transpose( |
|
tf.reshape( |
|
shown_x, |
|
[1, image_size * side_shown, image_size * side_shown, 3]), |
|
[0, 2, 1, 3]) * 255. |
|
tf.summary.image( |
|
"inputs", |
|
tf.cast(shown_x, tf.uint8), |
|
max_outputs=1) |
|
|
|
|
|
FLAGS.image_size = image_size |
|
data_constraint = hps.data_constraint |
|
pre_logit_scale = numpy.log(data_constraint) |
|
pre_logit_scale -= numpy.log(1. - data_constraint) |
|
pre_logit_scale = tf.cast(pre_logit_scale, tf.float32) |
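    # Map x in [0, 1) to logit(data_constraint * (x - .5) + .5) so the flow
    # operates on an unbounded space; transform_cost below is the
    # log-Jacobian of this preprocessing, counted in the likelihood.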
|
logit_x_in = 2. * x_in |
|
logit_x_in -= 1. |
|
logit_x_in *= data_constraint |
|
logit_x_in += 1. |
|
logit_x_in /= 2. |
|
|
|
logit_x_in = tf.log(logit_x_in) - tf.log(1. - logit_x_in) |
|
transform_cost = tf.reduce_sum( |
|
tf.nn.softplus(logit_x_in) + tf.nn.softplus(-logit_x_in) |
|
- tf.nn.softplus(-pre_logit_scale), |
|
[1, 2, 3]) |
|
|
|
|
|
z_out, log_diff = encoder( |
|
input_=logit_x_in, hps=hps, n_scale=hps.n_scale, |
|
use_batch_norm=hps.use_batch_norm, weight_norm=True, |
|
train=True) |
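    # The train=True call above creates the variables; in eval and sampling
    # modes the encoder is rebuilt with train=False (reusing those variables)
    # so batch norm uses its accumulated moving averages.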
|
if FLAGS.mode != "train": |
|
z_out, log_diff = encoder( |
|
input_=logit_x_in, hps=hps, n_scale=hps.n_scale, |
|
use_batch_norm=hps.use_batch_norm, weight_norm=True, |
|
train=False) |
|
final_shape = [image_size, image_size, 3] |
|
prior_ll = standard_normal_ll(z_out) |
|
prior_ll = tf.reduce_sum(prior_ll, [1, 2, 3]) |
|
log_diff = tf.reduce_sum(log_diff, [1, 2, 3]) |
|
log_diff += transform_cost |
|
cost = -(prior_ll + log_diff) |
|
|
|
self.x_in = x_in |
|
self.z_out = z_out |
|
self.cost = cost = tf.reduce_mean(cost) |
|
|
|
l2_reg = sum( |
|
[tf.reduce_sum(tf.square(v)) for v in tf.trainable_variables() |
|
if ("magnitude" in v.name) or ("rescaling_scale" in v.name)]) |
|
|
|
bit_per_dim = ((cost + numpy.log(256.) * image_size * image_size * 3.) |
|
/ (image_size * image_size * 3. * numpy.log(2.))) |
|
self.bit_per_dim = bit_per_dim |
|
|
|
|
|
momentum = 1. - hps.momentum |
|
decay = 1. - hps.decay |
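    # hps.momentum and hps.decay store complements, so the defaults 1e-1 and
    # 1e-3 yield beta1 = 0.9 and beta2 = 0.999 for Adam.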
|
if hps.optimizer == "adam": |
|
optimizer = tf.train.AdamOptimizer( |
|
learning_rate=hps.learning_rate, |
|
beta1=momentum, beta2=decay, epsilon=1e-08, |
|
use_locking=False, name="Adam") |
|
elif hps.optimizer == "rmsprop": |
|
optimizer = tf.train.RMSPropOptimizer( |
|
learning_rate=hps.learning_rate, decay=decay, |
|
momentum=momentum, epsilon=1e-04, |
|
use_locking=False, name="RMSProp") |
|
else: |
|
optimizer = tf.train.MomentumOptimizer(hps.learning_rate, |
|
momentum=momentum) |
|
|
|
step = tf.get_variable( |
|
"global_step", [], tf.int64, |
|
tf.zeros_initializer(), |
|
trainable=False) |
|
self.step = step |
|
grads_and_vars = optimizer.compute_gradients( |
|
cost + hps.l2_coeff * l2_reg, |
|
tf.trainable_variables()) |
|
grads, vars_ = zip(*grads_and_vars) |
|
capped_grads, gradient_norm = tf.clip_by_global_norm( |
|
grads, clip_norm=hps.clip_gradient) |
|
gradient_norm = tf.check_numerics(gradient_norm, |
|
"Gradient norm is NaN or Inf.") |
|
|
|
l2_z = tf.reduce_sum(tf.square(z_out), [1, 2, 3]) |
|
if not sampling: |
|
tf.summary.scalar("negative_log_likelihood", tf.reshape(cost, [])) |
|
tf.summary.scalar("gradient_norm", tf.reshape(gradient_norm, [])) |
|
tf.summary.scalar("bit_per_dim", tf.reshape(bit_per_dim, [])) |
|
tf.summary.scalar("log_diff", tf.reshape(tf.reduce_mean(log_diff), [])) |
|
tf.summary.scalar("prior_ll", tf.reshape(tf.reduce_mean(prior_ll), [])) |
|
tf.summary.scalar( |
|
"log_diff_var", |
|
tf.reshape(tf.reduce_mean(tf.square(log_diff)) |
|
- tf.square(tf.reduce_mean(log_diff)), [])) |
|
tf.summary.scalar( |
|
"prior_ll_var", |
|
tf.reshape(tf.reduce_mean(tf.square(prior_ll)) |
|
- tf.square(tf.reduce_mean(prior_ll)), [])) |
|
tf.summary.scalar("l2_z_mean", tf.reshape(tf.reduce_mean(l2_z), [])) |
|
tf.summary.scalar( |
|
"l2_z_var", |
|
tf.reshape(tf.reduce_mean(tf.square(l2_z)) |
|
- tf.square(tf.reduce_mean(l2_z)), [])) |
|
|
|
|
|
capped_grads_and_vars = zip(capped_grads, vars_) |
|
self.train_step = optimizer.apply_gradients( |
|
capped_grads_and_vars, global_step=step) |
|
|
|
|
|
if sampling: |
|
|
|
sample = standard_normal_sample([100] + final_shape) |
|
sample, _ = decoder( |
|
input_=sample, hps=hps, n_scale=hps.n_scale, |
|
use_batch_norm=hps.use_batch_norm, weight_norm=True, |
|
train=True) |
|
sample = tf.nn.sigmoid(sample) |
|
|
|
sample = tf.clip_by_value(sample, 0, 1) * 255. |
|
sample = tf.reshape(sample, [100, image_size, image_size, 3]) |
|
sample = tf.transpose( |
|
tf.reshape(sample, [10, image_size * 10, image_size, 3]), |
|
[0, 2, 1, 3]) |
|
sample = tf.transpose( |
|
tf.reshape(sample, [1, image_size * 10, image_size * 10, 3]), |
|
[0, 2, 1, 3]) |
|
tf.summary.image( |
|
"samples", |
|
tf.cast(sample, tf.uint8), |
|
max_outputs=1) |
|
|
|
|
|
concatenation, _ = encoder( |
|
input_=logit_x_in, hps=hps, |
|
n_scale=hps.n_scale, |
|
use_batch_norm=hps.use_batch_norm, weight_norm=True, |
|
train=False) |
|
concatenation = tf.reshape( |
|
concatenation, |
|
[(side_shown * side_shown), image_size, image_size, 3]) |
|
concatenation = tf.transpose( |
|
tf.reshape( |
|
concatenation, |
|
[side_shown, image_size * side_shown, image_size, 3]), |
|
[0, 2, 1, 3]) |
|
concatenation = tf.transpose( |
|
tf.reshape( |
|
concatenation, |
|
[1, image_size * side_shown, image_size * side_shown, 3]), |
|
[0, 2, 1, 3]) |
|
concatenation, _ = decoder( |
|
input_=concatenation, hps=hps, n_scale=hps.n_scale, |
|
use_batch_norm=hps.use_batch_norm, weight_norm=True, |
|
train=False) |
|
concatenation = tf.nn.sigmoid(concatenation) * 255. |
|
tf.summary.image( |
|
"concatenation", |
|
tf.cast(concatenation, tf.uint8), |
|
max_outputs=1) |
|
|
|
|
|
|
|
|
|
z_u, _ = encoder( |
|
input_=logit_x_in[:8, :, :, :], hps=hps, |
|
n_scale=hps.n_scale, |
|
use_batch_norm=hps.use_batch_norm, weight_norm=True, |
|
train=False) |
|
u_1 = tf.reshape(z_u[0, :, :, :], [-1]) |
|
u_2 = tf.reshape(z_u[1, :, :, :], [-1]) |
|
u_3 = tf.reshape(z_u[2, :, :, :], [-1]) |
|
u_4 = tf.reshape(z_u[3, :, :, :], [-1]) |
|
u_5 = tf.reshape(z_u[4, :, :, :], [-1]) |
|
u_6 = tf.reshape(z_u[5, :, :, :], [-1]) |
|
u_7 = tf.reshape(z_u[6, :, :, :], [-1]) |
|
u_8 = tf.reshape(z_u[7, :, :, :], [-1]) |
|
|
|
|
|
manifold_side = 8 |
|
angle_1 = numpy.arange(manifold_side) * 1. / manifold_side |
|
angle_2 = numpy.arange(manifold_side) * 1. / manifold_side |
|
angle_1 *= 2. * numpy.pi |
|
angle_2 *= 2. * numpy.pi |
|
angle_1 = angle_1.astype("float32") |
|
angle_2 = angle_2.astype("float32") |
|
angle_1 = tf.reshape(angle_1, [1, -1, 1]) |
|
angle_1 += tf.zeros([manifold_side, manifold_side, 1]) |
|
angle_2 = tf.reshape(angle_2, [-1, 1, 1]) |
|
angle_2 += tf.zeros([manifold_side, manifold_side, 1]) |
|
n_angle_3 = 40 |
|
angle_3 = numpy.arange(n_angle_3) * 1. / n_angle_3 |
|
angle_3 *= 2 * numpy.pi |
|
angle_3 = angle_3.astype("float32") |
|
angle_3 = tf.reshape(angle_3, [-1, 1, 1, 1]) |
|
angle_3 += tf.zeros([n_angle_3, manifold_side, manifold_side, 1]) |
|
manifold = tf.cos(angle_1) * ( |
|
tf.cos(angle_2) * ( |
|
tf.cos(angle_3) * u_1 + tf.sin(angle_3) * u_2) |
|
+ tf.sin(angle_2) * ( |
|
tf.cos(angle_3) * u_3 + tf.sin(angle_3) * u_4)) |
|
manifold += tf.sin(angle_1) * ( |
|
tf.cos(angle_2) * ( |
|
tf.cos(angle_3) * u_5 + tf.sin(angle_3) * u_6) |
|
+ tf.sin(angle_2) * ( |
|
tf.cos(angle_3) * u_7 + tf.sin(angle_3) * u_8)) |
|
manifold = tf.reshape( |
|
manifold, |
|
[n_angle_3 * manifold_side * manifold_side] + final_shape) |
|
manifold, _ = decoder( |
|
input_=manifold, hps=hps, n_scale=hps.n_scale, |
|
use_batch_norm=hps.use_batch_norm, weight_norm=True, |
|
train=False) |
|
manifold = tf.nn.sigmoid(manifold) |
|
|
|
manifold = tf.clip_by_value(manifold, 0, 1) * 255. |
|
manifold = tf.reshape( |
|
manifold, |
|
[n_angle_3, |
|
manifold_side * manifold_side, |
|
image_size, |
|
image_size, |
|
3]) |
|
manifold = tf.transpose( |
|
tf.reshape( |
|
manifold, |
|
[n_angle_3, manifold_side, |
|
image_size * manifold_side, image_size, 3]), [0, 1, 3, 2, 4]) |
|
manifold = tf.transpose( |
|
tf.reshape( |
|
manifold, |
|
[n_angle_3, image_size * manifold_side, |
|
image_size * manifold_side, 3]), |
|
[0, 2, 1, 3]) |
|
manifold = tf.transpose(manifold, [1, 2, 0, 3]) |
|
manifold = tf.reshape( |
|
manifold, |
|
[1, image_size * manifold_side, |
|
image_size * manifold_side, 3 * n_angle_3]) |
|
tf.summary.image( |
|
"manifold", |
|
tf.cast(manifold[:, :, :, :3], tf.uint8), |
|
max_outputs=1) |
|
|
|
|
|
z_complete, _ = encoder( |
|
input_=logit_x_in[:hps.n_scale, :, :, :], hps=hps, |
|
n_scale=hps.n_scale, |
|
use_batch_norm=hps.use_batch_norm, weight_norm=True, |
|
train=False) |
|
z_compressed_list = [z_complete] |
|
z_noisy_list = [z_complete] |
|
z_lost = z_complete |
|
for scale_idx in xrange(hps.n_scale - 1): |
|
z_lost = squeeze_2x2_ordered(z_lost) |
|
z_lost, _ = tf.split(axis=3, num_or_size_splits=2, value=z_lost) |
|
z_compressed = z_lost |
|
z_noisy = z_lost |
|
for _ in xrange(scale_idx + 1): |
|
z_compressed = tf.concat( |
|
[z_compressed, tf.zeros_like(z_compressed)], 3) |
|
z_compressed = squeeze_2x2_ordered( |
|
z_compressed, reverse=True) |
|
z_noisy = tf.concat( |
|
[z_noisy, tf.random_normal( |
|
z_noisy.get_shape().as_list())], 3) |
|
z_noisy = squeeze_2x2_ordered(z_noisy, reverse=True) |
|
z_compressed_list.append(z_compressed) |
|
z_noisy_list.append(z_noisy) |
|
self.z_reduced = z_lost |
|
z_compressed = tf.concat(z_compressed_list, 0) |
|
z_noisy = tf.concat(z_noisy_list, 0) |
|
noisy_images, _ = decoder( |
|
input_=z_noisy, hps=hps, n_scale=hps.n_scale, |
|
use_batch_norm=hps.use_batch_norm, weight_norm=True, |
|
train=False) |
|
compressed_images, _ = decoder( |
|
input_=z_compressed, hps=hps, n_scale=hps.n_scale, |
|
use_batch_norm=hps.use_batch_norm, weight_norm=True, |
|
train=False) |
|
noisy_images = tf.nn.sigmoid(noisy_images) |
|
compressed_images = tf.nn.sigmoid(compressed_images) |
|
|
|
noisy_images = tf.clip_by_value(noisy_images, 0, 1) * 255. |
|
noisy_images = tf.reshape( |
|
noisy_images, |
|
[(hps.n_scale * hps.n_scale), image_size, image_size, 3]) |
|
noisy_images = tf.transpose( |
|
tf.reshape( |
|
noisy_images, |
|
[hps.n_scale, image_size * hps.n_scale, image_size, 3]), |
|
[0, 2, 1, 3]) |
|
noisy_images = tf.transpose( |
|
tf.reshape( |
|
noisy_images, |
|
[1, image_size * hps.n_scale, image_size * hps.n_scale, 3]), |
|
[0, 2, 1, 3]) |
|
tf.summary.image( |
|
"noise", |
|
tf.cast(noisy_images, tf.uint8), |
|
max_outputs=1) |
|
compressed_images = tf.clip_by_value(compressed_images, 0, 1) * 255. |
|
compressed_images = tf.reshape( |
|
compressed_images, |
|
[(hps.n_scale * hps.n_scale), image_size, image_size, 3]) |
|
compressed_images = tf.transpose( |
|
tf.reshape( |
|
compressed_images, |
|
[hps.n_scale, image_size * hps.n_scale, image_size, 3]), |
|
[0, 2, 1, 3]) |
|
compressed_images = tf.transpose( |
|
tf.reshape( |
|
compressed_images, |
|
[1, image_size * hps.n_scale, image_size * hps.n_scale, 3]), |
|
[0, 2, 1, 3]) |
|
tf.summary.image( |
|
"compression", |
|
tf.cast(compressed_images, tf.uint8), |
|
max_outputs=1) |
|
|
|
|
|
final_shape[0] *= 2 |
|
final_shape[1] *= 2 |
|
big_sample = standard_normal_sample([25] + final_shape) |
|
big_sample, _ = decoder( |
|
input_=big_sample, hps=hps, n_scale=hps.n_scale, |
|
use_batch_norm=hps.use_batch_norm, weight_norm=True, |
|
train=True) |
|
big_sample = tf.nn.sigmoid(big_sample) |
|
|
|
big_sample = tf.clip_by_value(big_sample, 0, 1) * 255. |
|
big_sample = tf.reshape( |
|
big_sample, |
|
[25, image_size * 2, image_size * 2, 3]) |
|
big_sample = tf.transpose( |
|
tf.reshape( |
|
big_sample, |
|
[5, image_size * 10, image_size * 2, 3]), [0, 2, 1, 3]) |
|
big_sample = tf.transpose( |
|
tf.reshape( |
|
big_sample, |
|
[1, image_size * 10, image_size * 10, 3]), |
|
[0, 2, 1, 3]) |
|
tf.summary.image( |
|
"big_sample", |
|
tf.cast(big_sample, tf.uint8), |
|
max_outputs=1) |
|
|
|
|
|
final_shape[0] *= 5 |
|
final_shape[1] *= 5 |
|
extra_large = standard_normal_sample([1] + final_shape) |
|
extra_large, _ = decoder( |
|
input_=extra_large, hps=hps, n_scale=hps.n_scale, |
|
use_batch_norm=hps.use_batch_norm, weight_norm=True, |
|
train=True) |
|
extra_large = tf.nn.sigmoid(extra_large) |
|
|
|
extra_large = tf.clip_by_value(extra_large, 0, 1) * 255. |
|
tf.summary.image( |
|
"extra_large", |
|
tf.cast(extra_large, tf.uint8), |
|
max_outputs=1) |
|
|
|
def eval_epoch(self, hps): |
|
"""Evaluate bits/dim.""" |
|
n_eval_dict = { |
|
"imnet": 50000, |
|
"lsun": 300, |
|
"celeba": 19962, |
|
"svhn": 26032, |
|
} |
|
if FLAGS.eval_set_size == 0: |
|
num_examples_eval = n_eval_dict[FLAGS.dataset] |
|
else: |
|
num_examples_eval = FLAGS.eval_set_size |
|
    n_epoch = num_examples_eval // hps.batch_size
|
eval_costs = [] |
|
bar_len = 70 |
|
for epoch_idx in xrange(n_epoch): |
|
n_equal = epoch_idx * bar_len * 1. / n_epoch |
|
n_equal = numpy.ceil(n_equal) |
|
n_equal = int(n_equal) |
|
n_dash = bar_len - n_equal |
|
progress_bar = "[" + "=" * n_equal + "-" * n_dash + "]\r" |
|
print(progress_bar, end=' ') |
|
cost = self.bit_per_dim.eval() |
|
eval_costs.append(cost) |
|
print("") |
|
return float(numpy.mean(eval_costs)) |
|
|
|
|
|
def train_model(hps, logdir): |
|
"""Training.""" |
|
with tf.Graph().as_default(): |
|
with tf.device(tf.train.replica_device_setter(0)): |
|
with tf.variable_scope("model"): |
|
model = RealNVP(hps) |
|
|
|
saver = tf.train.Saver(tf.global_variables()) |
|
|
|
|
|
summary_op = tf.summary.merge_all() |
|
|
|
|
|
init = tf.global_variables_initializer() |
|
|
|
|
|
|
|
|
|
sess = tf.Session(config=tf.ConfigProto( |
|
allow_soft_placement=True, |
|
log_device_placement=True)) |
|
sess.run(init) |
|
|
|
ckpt_state = tf.train.get_checkpoint_state(logdir) |
|
if ckpt_state and ckpt_state.model_checkpoint_path: |
|
print("Loading file %s" % ckpt_state.model_checkpoint_path) |
|
saver.restore(sess, ckpt_state.model_checkpoint_path) |
|
|
|
|
|
tf.train.start_queue_runners(sess=sess) |
|
|
|
summary_writer = tf.summary.FileWriter( |
|
logdir, |
|
graph=sess.graph) |
|
|
|
local_step = 0 |
|
while True: |
|
fetches = [model.step, model.bit_per_dim, model.train_step] |
|
|
|
should_eval_summaries = local_step % 100 == 0 |
|
if should_eval_summaries: |
|
fetches += [summary_op] |
|
|
|
|
|
start_time = time.time() |
|
outputs = sess.run(fetches) |
|
global_step_val = outputs[0] |
|
loss = outputs[1] |
|
duration = time.time() - start_time |
|
      assert not numpy.isnan(loss), "Model diverged with loss = NaN"
|
|
|
if local_step % 10 == 0: |
|
examples_per_sec = hps.batch_size / float(duration) |
|
format_str = ('%s: step %d, loss = %.2f ' |
|
'(%.1f examples/sec; %.3f ' |
|
'sec/batch)') |
|
print(format_str % (datetime.now(), global_step_val, loss, |
|
examples_per_sec, duration)) |
|
|
|
if should_eval_summaries: |
|
summary_str = outputs[-1] |
|
summary_writer.add_summary(summary_str, global_step_val) |
|
|
|
|
|
if local_step % 1000 == 0 or (local_step + 1) == FLAGS.train_steps: |
|
checkpoint_path = os.path.join(logdir, 'model.ckpt') |
|
saver.save( |
|
sess, |
|
checkpoint_path, |
|
global_step=global_step_val) |
|
|
|
if outputs[0] >= FLAGS.train_steps: |
|
break |
|
|
|
local_step += 1 |
|
|
|
|
|
def evaluate(hps, logdir, traindir, subset="valid", return_val=False): |
|
"""Evaluation.""" |
|
hps.batch_size = 100 |
|
with tf.Graph().as_default(): |
|
with tf.device("/cpu:0"): |
|
with tf.variable_scope("model") as var_scope: |
|
eval_model = RealNVP(hps) |
|
summary_writer = tf.summary.FileWriter(logdir) |
|
var_scope.reuse_variables() |
|
|
|
saver = tf.train.Saver() |
|
sess = tf.Session(config=tf.ConfigProto( |
|
allow_soft_placement=True, |
|
log_device_placement=True)) |
|
tf.train.start_queue_runners(sess) |
|
|
|
previous_global_step = 0 |
|
|
|
with sess.as_default(): |
|
while True: |
|
ckpt_state = tf.train.get_checkpoint_state(traindir) |
|
if not (ckpt_state and ckpt_state.model_checkpoint_path): |
|
print("No model to eval yet at %s" % traindir) |
|
time.sleep(30) |
|
continue |
|
print("Loading file %s" % ckpt_state.model_checkpoint_path) |
|
saver.restore(sess, ckpt_state.model_checkpoint_path) |
|
|
|
current_step = tf.train.global_step(sess, eval_model.step) |
|
if current_step == previous_global_step: |
|
print("Waiting for the checkpoint to be updated.") |
|
time.sleep(30) |
|
continue |
|
previous_global_step = current_step |
|
|
|
print("Evaluating...") |
|
bit_per_dim = eval_model.eval_epoch(hps) |
|
print("Epoch: %d, %s -> %.3f bits/dim" |
|
% (current_step, subset, bit_per_dim)) |
|
print("Writing summary...") |
|
summary = tf.Summary() |
|
summary.value.extend( |
|
[tf.Summary.Value( |
|
tag="bit_per_dim", |
|
simple_value=bit_per_dim)]) |
|
summary_writer.add_summary(summary, current_step) |
|
|
|
if return_val: |
|
return current_step, bit_per_dim |
|
|
|
|
|
def sample_from_model(hps, logdir, traindir): |
|
"""Sampling.""" |
|
hps.batch_size = 100 |
|
with tf.Graph().as_default(): |
|
with tf.device("/cpu:0"): |
|
with tf.variable_scope("model") as var_scope: |
|
eval_model = RealNVP(hps, sampling=True) |
|
summary_writer = tf.summary.FileWriter(logdir) |
|
var_scope.reuse_variables() |
|
|
|
summary_op = tf.summary.merge_all() |
|
saver = tf.train.Saver() |
|
sess = tf.Session(config=tf.ConfigProto( |
|
allow_soft_placement=True, |
|
log_device_placement=True)) |
|
coord = tf.train.Coordinator() |
|
threads = tf.train.start_queue_runners(sess=sess, coord=coord) |
|
|
|
previous_global_step = 0 |
|
|
|
initialized = False |
|
with sess.as_default(): |
|
while True: |
|
ckpt_state = tf.train.get_checkpoint_state(traindir) |
|
if not (ckpt_state and ckpt_state.model_checkpoint_path): |
|
if not initialized: |
|
print("No model to eval yet at %s" % traindir) |
|
time.sleep(30) |
|
continue |
|
else: |
|
print ("Loading file %s" |
|
% ckpt_state.model_checkpoint_path) |
|
saver.restore(sess, ckpt_state.model_checkpoint_path) |
|
|
|
current_step = tf.train.global_step(sess, eval_model.step) |
|
if current_step == previous_global_step: |
|
print("Waiting for the checkpoint to be updated.") |
|
time.sleep(30) |
|
continue |
|
previous_global_step = current_step |
|
|
|
fetches = [summary_op] |
|
|
|
outputs = sess.run(fetches) |
|
summary_writer.add_summary(outputs[0], current_step) |
|
coord.request_stop() |
|
coord.join(threads) |
|
|
|
|
|
def main(unused_argv): |
|
hps = get_default_hparams().update_config(FLAGS.hpconfig) |
|
if FLAGS.mode == "train": |
|
train_model(hps=hps, logdir=FLAGS.logdir) |
|
elif FLAGS.mode == "sample": |
|
sample_from_model(hps=hps, logdir=FLAGS.logdir, |
|
traindir=FLAGS.traindir) |
|
else: |
|
hps.batch_size = 100 |
|
evaluate(hps=hps, logdir=FLAGS.logdir, |
|
traindir=FLAGS.traindir, subset=FLAGS.mode) |
|
|
|
if __name__ == "__main__": |
|
tf.app.run() |
|
|