# UGATIT / ugatit/ops.py
import tensorflow as tf
import tensorflow.contrib as tf_contrib
# Xavier : tf_contrib.layers.xavier_initializer()
# He : tf_contrib.layers.variance_scaling_initializer()
# Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)
# l2_decay : tf_contrib.layers.l2_regularizer(0.0001)
weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
weight_regularizer = tf_contrib.layers.l2_regularizer(scale=0.0001)
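
# The alternatives listed above could be swapped in for `weight_init` / `weight_regularizer`;
# kept commented out here so the module's defaults stay as-is:
# weight_init = tf_contrib.layers.variance_scaling_initializer()  # He
# weight_init = tf_contrib.layers.xavier_initializer()            # Xavier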
##################################################################################
# Layer
##################################################################################
def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
    with tf.variable_scope(scope):
        if pad > 0:
            # Use symmetric padding when (kernel - stride) is even, otherwise pad
            # asymmetrically so the spatial size works out.
            if (kernel - stride) % 2 == 0:
                pad_top = pad
                pad_bottom = pad
                pad_left = pad
                pad_right = pad
            else:
                pad_top = pad
                pad_bottom = kernel - stride - pad_top
                pad_left = pad
                pad_right = kernel - stride - pad_left

            if pad_type == 'zero':
                x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
            if pad_type == 'reflect':
                x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')

        if sn:
            # Spectrally normalized convolution with an explicit kernel variable.
            w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels],
                                initializer=weight_init, regularizer=weight_regularizer)
            x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
                             strides=[1, stride, stride, 1], padding='VALID')
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)
        else:
            x = tf.layers.conv2d(inputs=x, filters=channels,
                                 kernel_size=kernel, kernel_initializer=weight_init,
                                 kernel_regularizer=weight_regularizer,
                                 strides=stride, use_bias=use_bias)

        return x

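# Example call (illustrative only; `inputs` is assumed to be an NHWC float tensor):
#   y = conv(inputs, channels=64, kernel=4, stride=2, pad=1,
#            pad_type='reflect', sn=True, scope='conv_example')
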
def fully_connected_with_w(x, use_bias=True, sn=False, reuse=False, scope='linear'):
    with tf.variable_scope(scope, reuse=reuse):
        x = flatten(x)
        bias = 0.0
        shape = x.get_shape().as_list()
        channels = shape[-1]

        w = tf.get_variable("kernel", [channels, 1], tf.float32,
                            initializer=weight_init, regularizer=weight_regularizer)

        if sn:
            w = spectral_norm(w)

        if use_bias:
            bias = tf.get_variable("bias", [1],
                                   initializer=tf.constant_initializer(0.0))

            x = tf.matmul(x, w) + bias
        else:
            x = tf.matmul(x, w)

        if use_bias:
            weights = tf.gather(tf.transpose(tf.nn.bias_add(w, bias)), 0)
        else:
            weights = tf.gather(tf.transpose(w), 0)

        return x, weights

def fully_connected(x, units, use_bias=True, sn=False, scope='linear'):
    with tf.variable_scope(scope):
        x = flatten(x)
        shape = x.get_shape().as_list()
        channels = shape[-1]

        if sn:
            w = tf.get_variable("kernel", [channels, units], tf.float32,
                                initializer=weight_init, regularizer=weight_regularizer)
            if use_bias:
                bias = tf.get_variable("bias", [units],
                                       initializer=tf.constant_initializer(0.0))

                x = tf.matmul(x, spectral_norm(w)) + bias
            else:
                x = tf.matmul(x, spectral_norm(w))
        else:
            x = tf.layers.dense(x, units=units, kernel_initializer=weight_init,
                                kernel_regularizer=weight_regularizer, use_bias=use_bias)

        return x


def flatten(x):
    return tf.layers.flatten(x)

##################################################################################
# Residual-block
##################################################################################
def resblock(x_init, channels, use_bias=True, scope='resblock_0'):
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
            x = instance_norm(x)
            x = relu(x)

        with tf.variable_scope('res2'):
            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
            x = instance_norm(x)

        return x + x_init

def adaptive_ins_layer_resblock(x_init, channels, gamma, beta, use_bias=True, smoothing=True, scope='adaptive_resblock'):
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
            x = adaptive_instance_layer_norm(x, gamma, beta, smoothing)
            x = relu(x)

        with tf.variable_scope('res2'):
            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
            x = adaptive_instance_layer_norm(x, gamma, beta, smoothing)

        return x + x_init

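# In the UGATIT generator, `gamma` and `beta` are typically produced by fully connected layers
# from an intermediate feature. Minimal sketch only; the tensor and scope names are hypothetical:
#   gamma = fully_connected(feature, channels, scope='gamma')
#   beta = fully_connected(feature, channels, scope='beta')
#   x = adaptive_ins_layer_resblock(x, channels, gamma, beta, scope='adaptive_resblock_0')
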
##################################################################################
# Sampling
##################################################################################
def up_sample(x, scale_factor=2):
    # Nearest-neighbour upsampling; relies on a statically known spatial shape.
    _, h, w, _ = x.get_shape().as_list()
    new_size = [h * scale_factor, w * scale_factor]

    return tf.image.resize_nearest_neighbor(x, size=new_size)


def global_avg_pooling(x):
    gap = tf.reduce_mean(x, axis=[1, 2])

    return gap


def global_max_pooling(x):
    gmp = tf.reduce_max(x, axis=[1, 2])

    return gmp

##################################################################################
# Activation function
##################################################################################
def lrelu(x, alpha=0.01):
    # pytorch alpha is 0.01
    return tf.nn.leaky_relu(x, alpha)


def relu(x):
    return tf.nn.relu(x)


def tanh(x):
    return tf.tanh(x)


def sigmoid(x):
    return tf.sigmoid(x)

##################################################################################
# Normalization function
##################################################################################
def adaptive_instance_layer_norm(x, gamma, beta, smoothing=True, scope='instance_layer_norm'):
    # AdaILN: interpolate between instance-normalized and layer-normalized features with a
    # learned, per-channel ratio rho, then apply the externally supplied gamma/beta.
    with tf.variable_scope(scope):
        ch = x.shape[-1]
        eps = 1e-5

        ins_mean, ins_sigma = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
        x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps))

        ln_mean, ln_sigma = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True)
        x_ln = (x - ln_mean) / (tf.sqrt(ln_sigma + eps))

        rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(1.0),
                              constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0))

        if smoothing:
            rho = tf.clip_by_value(rho - tf.constant(0.1), 0.0, 1.0)

        x_hat = rho * x_ins + (1 - rho) * x_ln

        x_hat = x_hat * gamma + beta

        return x_hat

def instance_norm(x, scope='instance_norm'):
    return tf_contrib.layers.instance_norm(x,
                                           epsilon=1e-05,
                                           center=True, scale=True,
                                           scope=scope)


def layer_norm(x, scope='layer_norm'):
    return tf_contrib.layers.layer_norm(x,
                                        center=True, scale=True,
                                        scope=scope)

def layer_instance_norm(x, scope='layer_instance_norm'):
    # ILN: same interpolation as AdaILN, but gamma/beta are learned inside the layer.
    with tf.variable_scope(scope):
        ch = x.shape[-1]
        eps = 1e-5

        ins_mean, ins_sigma = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
        x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps))

        ln_mean, ln_sigma = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True)
        x_ln = (x - ln_mean) / (tf.sqrt(ln_sigma + eps))

        rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(0.0),
                              constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0))

        gamma = tf.get_variable("gamma", [ch], initializer=tf.constant_initializer(1.0))
        beta = tf.get_variable("beta", [ch], initializer=tf.constant_initializer(0.0))

        x_hat = rho * x_ins + (1 - rho) * x_ln

        x_hat = x_hat * gamma + beta

        return x_hat

def spectral_norm(w, iteration=1):
    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])

    u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)

    u_hat = u
    v_hat = None
    for i in range(iteration):
        # Power iteration to estimate the largest singular value of w.
        # Usually iteration = 1 is enough.
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = tf.nn.l2_normalize(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = tf.nn.l2_normalize(u_)

    u_hat = tf.stop_gradient(u_hat)
    v_hat = tf.stop_gradient(v_hat)

    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))

    with tf.control_dependencies([u.assign(u_hat)]):
        # Divide by the estimated spectral norm and restore the original kernel shape.
        w_norm = w / sigma
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm

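# `spectral_norm` creates a persistent power-iteration vector "u" via tf.get_variable, so each
# call must live in its own variable scope (as it does inside `conv` and `fully_connected`).
# Minimal sketch, assuming a hypothetical scope name:
#   with tf.variable_scope('sn_example'):
#       w = tf.get_variable("kernel", [3, 3, 64, 128], initializer=weight_init)
#       w_sn = spectral_norm(w)
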
##################################################################################
# Loss function
##################################################################################
def L1_loss(x, y):
    loss = tf.reduce_mean(tf.abs(x - y))

    return loss


def cam_loss(source, non_source):
    identity_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(source), logits=source))

    non_identity_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(non_source), logits=non_source))

    loss = identity_loss + non_identity_loss

    return loss

def regularization_loss(scope_name):
    """
    If you want to use "Regularization"
    g_loss += regularization_loss('generator')
    d_loss += regularization_loss('discriminator')
    """
    collection_regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

    loss = []
    for item in collection_regularization:
        if scope_name in item.name:
            loss.append(item)

    return tf.reduce_sum(loss)

def discriminator_loss(loss_func, real, fake):
    # `real` and `fake` are indexable collections of two logit tensors each;
    # the loop sums the adversarial loss over both of them.
    loss = []
    real_loss = 0
    fake_loss = 0

    for i in range(2):
        if loss_func.__contains__('wgan'):
            real_loss = -tf.reduce_mean(real[i])
            fake_loss = tf.reduce_mean(fake[i])

        if loss_func == 'lsgan':
            real_loss = tf.reduce_mean(tf.squared_difference(real[i], 1.0))
            fake_loss = tf.reduce_mean(tf.square(fake[i]))

        if loss_func == 'gan' or loss_func == 'dragan':
            real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real[i]), logits=real[i]))
            fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake[i]), logits=fake[i]))

        if loss_func == 'hinge':
            real_loss = tf.reduce_mean(relu(1.0 - real[i]))
            fake_loss = tf.reduce_mean(relu(1.0 + fake[i]))

        loss.append(real_loss + fake_loss)

    return sum(loss)

def generator_loss(loss_func, fake):
    # `fake` is an indexable collection of two logit tensors, matching discriminator_loss.
    loss = []
    fake_loss = 0

    for i in range(2):
        if loss_func.__contains__('wgan'):
            fake_loss = -tf.reduce_mean(fake[i])

        if loss_func == 'lsgan':
            fake_loss = tf.reduce_mean(tf.squared_difference(fake[i], 1.0))

        if loss_func == 'gan' or loss_func == 'dragan':
            fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake[i]), logits=fake[i]))

        if loss_func == 'hinge':
            fake_loss = -tf.reduce_mean(fake[i])

        loss.append(fake_loss)

    return sum(loss)

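# Minimal wiring sketch (illustrative only; `real_logits` and `fake_logits` are hypothetical
# lists holding two discriminator logit tensors each, e.g. a patch logit and a CAM logit):
#   d_loss = discriminator_loss('lsgan', real=real_logits, fake=fake_logits)
#   g_loss = generator_loss('lsgan', fake=fake_logits)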