Spaces:
Sleeping
Sleeping
# This file is based on the GauGAN Keras example by Rakshit et al.
# https://keras.io/examples/generative/gaugan/
import tensorflow as tf | |
import tensorflow_addons as tfa | |
class SPADE(tf.keras.layers.Layer):
    """Spatially-adaptive normalization conditioned on a segmentation mask.

    The incoming feature map is normalized with statistics taken over axes
    ``(0, 1, 2)`` and then modulated element-wise by a scale (gamma) and a
    shift (beta) that are predicted from the mask by small convolutions.

    Args:
        filters: Number of channels of the predicted gamma and beta maps.
            They are multiplied with / added to the normalized input, so
            this must be broadcast-compatible with the input's channels.
        epsilon: Constant added to the variance before taking the square
            root.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        # Shared trunk applied to the resized mask before the two heads.
        self.conv = tf.keras.layers.Conv2D(
            filters=128, kernel_size=3, padding="same", activation="relu"
        )
        self.conv_gamma = tf.keras.layers.Conv2D(
            filters=filters, kernel_size=3, padding="same"
        )
        self.conv_beta = tf.keras.layers.Conv2D(
            filters=filters, kernel_size=3, padding="same"
        )

    def build(self, input_shape):
        """Remember axes 1 and 2 of the input shape as the mask resize target."""
        self.resize_shape = input_shape[1:3]

    def call(self, input_tensor, raw_mask):
        """Normalize ``input_tensor`` and modulate it with mask-derived maps.

        Args:
            input_tensor: Feature map to normalize.
            raw_mask: Segmentation mask; it is resized (nearest neighbour)
                to the size recorded in ``build`` before use.

        Returns:
            ``gamma * normalized + beta``.
        """
        resized_mask = tf.image.resize(
            raw_mask, self.resize_shape, method="nearest"
        )
        mask_features = self.conv(resized_mask)
        gamma = self.conv_gamma(mask_features)
        beta = self.conv_beta(mask_features)

        # Parameter-free normalization over axes 0, 1 and 2.
        mean, variance = tf.nn.moments(
            input_tensor, axes=(0, 1, 2), keepdims=True
        )
        normalized = (input_tensor - mean) / tf.sqrt(variance + self.epsilon)
        return gamma * normalized + beta
class ResBlock(tf.keras.layers.Layer):
    """Residual block whose normalization layers are SPADE layers.

    The main branch is ``SPADE -> LeakyReLU -> Conv`` applied twice. When the
    requested number of filters differs from the input's channel count, the
    shortcut is also passed through ``SPADE -> LeakyReLU -> Conv`` so that
    both branches can be added.

    Args:
        filters: Number of output channels of the block.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

    def build(self, input_shape):
        """Create the sub-layers once the input channel count is known."""
        in_channels = input_shape[-1]

        self.spade_1 = SPADE(in_channels)
        self.spade_2 = SPADE(self.filters)
        self.conv_1 = tf.keras.layers.Conv2D(
            filters=self.filters, kernel_size=3, padding="same"
        )
        self.conv_2 = tf.keras.layers.Conv2D(
            filters=self.filters, kernel_size=3, padding="same"
        )

        # A learned shortcut is only needed when the channel count changes.
        self.learned_skip = self.filters != in_channels
        if self.learned_skip:
            self.spade_3 = SPADE(in_channels)
            self.conv_3 = tf.keras.layers.Conv2D(
                filters=self.filters, kernel_size=3, padding="same"
            )

    def call(self, input_tensor, mask):
        """Apply the block to ``input_tensor`` conditioned on ``mask``.

        Args:
            input_tensor: Feature map entering the block.
            mask: Segmentation mask handed to every SPADE layer.

        Returns:
            Sum of the shortcut branch and the main branch.
        """
        residual = self.spade_1(input_tensor, mask)
        residual = self.conv_1(tf.nn.leaky_relu(residual, 0.2))
        residual = self.spade_2(residual, mask)
        residual = self.conv_2(tf.nn.leaky_relu(residual, 0.2))

        if self.learned_skip:
            shortcut = self.spade_3(input_tensor, mask)
            shortcut = self.conv_3(tf.nn.leaky_relu(shortcut, 0.2))
        else:
            shortcut = input_tensor

        return shortcut + residual
class GaussianSampler(tf.keras.layers.Layer):
    """Draw latent samples with the reparameterization ``mean + std * noise``.

    Args:
        batch_size: First dimension of the sampled noise tensor.
        latent_dim: Second dimension of the sampled noise tensor.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, batch_size, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.latent_dim = latent_dim

    def call(self, inputs):
        """Sample from the distribution described by ``inputs``.

        Args:
            inputs: Pair ``(means, variance)``. ``variance`` is used as a
                log-variance: the standard deviation is
                ``exp(0.5 * variance)``.

        Returns:
            ``means + exp(0.5 * variance) * noise`` where ``noise`` is
            standard normal with shape ``(batch_size, latent_dim)``.
        """
        means, variance = inputs
        # The noise shape comes from the constructor arguments, not from
        # the runtime shape of ``means``.
        noise = tf.random.normal(
            shape=(self.batch_size, self.latent_dim), mean=0.0, stddev=1.0
        )
        std = tf.exp(0.5 * variance)
        return means + std * noise
def downsample(
    channels,
    kernels,
    strides=2,
    apply_norm=True,
    apply_activation=True,
    apply_dropout=False,
):
    """Build a convolutional block as a ``tf.keras.Sequential``.

    The block is a bias-free, "same"-padded ``Conv2D`` optionally followed,
    in this order, by instance normalization, ``LeakyReLU(0.2)`` and
    ``Dropout(0.5)``.

    Args:
        channels: Number of convolution filters.
        kernels: Convolution kernel size.
        strides: Convolution stride.
        apply_norm: Whether to append ``InstanceNormalization``.
        apply_activation: Whether to append ``LeakyReLU(0.2)``.
        apply_dropout: Whether to append ``Dropout(0.5)``.

    Returns:
        The assembled ``tf.keras.Sequential`` block.
    """
    stages = [
        tf.keras.layers.Conv2D(
            channels,
            kernels,
            strides=strides,
            padding="same",
            use_bias=False,
            kernel_initializer=tf.keras.initializers.GlorotNormal(),
        )
    ]
    if apply_norm:
        stages.append(tfa.layers.InstanceNormalization())
    if apply_activation:
        stages.append(tf.keras.layers.LeakyReLU(0.2))
    if apply_dropout:
        stages.append(tf.keras.layers.Dropout(0.5))
    return tf.keras.Sequential(stages)
def build_encoder(image_shape, encoder_downsample_factor=64, latent_dim=256):
    """Build the image encoder that predicts the latent distribution.

    The image passes through seven ``downsample`` blocks (kernel size 3,
    the first one without normalization), is flattened, and is fed to two
    parallel ``Dense`` heads named ``"mean"`` and ``"variance"``.

    Args:
        image_shape: Shape of one input image, without the batch axis.
        encoder_downsample_factor: Base number of filters; each block uses
            a multiple of it.
        latent_dim: Number of units of each output head.

    Returns:
        A ``tf.keras.Model`` named ``"encoder"`` mapping an image to
        ``[mean, variance]``.
    """
    input_image = tf.keras.Input(shape=image_shape)

    base = encoder_downsample_factor
    x = downsample(base, 3, apply_norm=False)(input_image)
    for multiplier in (2, 4, 8, 8, 8, 16):
        x = downsample(multiplier * base, 3)(x)

    flat = tf.keras.layers.Flatten()(x)
    mean = tf.keras.layers.Dense(latent_dim, name="mean")(flat)
    variance = tf.keras.layers.Dense(latent_dim, name="variance")(flat)
    return tf.keras.Model(input_image, [mean, variance], name="encoder")
def build_generator(mask_shape, latent_dim=256):
    """Build the SPADE generator.

    A latent vector is projected to a 4x4x1024 seed and then refined by
    eight ``ResBlock`` + 2x ``UpSampling2D`` stages, each conditioned on the
    mask. Eight doublings of a 4x4 seed fix the spatial output size at
    1024x1024. The final 3-channel convolution is followed by a sigmoid.

    Args:
        mask_shape: Shape of one mask, without the batch axis.
        latent_dim: Length of the latent vector.

    Returns:
        A ``tf.keras.Model`` named ``"generator"`` mapping
        ``[latent, mask]`` to the generated image.
    """
    # ``(latent_dim,)`` is a real one-element tuple; ``(latent_dim)`` is
    # just the integer itself.
    latent = tf.keras.Input(shape=(latent_dim,))
    mask = tf.keras.Input(shape=mask_shape)

    x = tf.keras.layers.Dense(16384)(latent)  # 16384 == 4 * 4 * 1024
    x = tf.keras.layers.Reshape((4, 4, 1024))(x)

    # Each stage doubles the resolution: 4 -> 8 -> ... -> 256 for the first
    # six stages; the 64- and 32-filter stages were added to reach an
    # output of 512x512 and then 1024x1024.
    for filters in (1024, 1024, 1024, 512, 256, 128, 64, 32):
        x = ResBlock(filters=filters)(x, mask)
        x = tf.keras.layers.UpSampling2D((2, 2))(x)

    x = tf.nn.leaky_relu(x, 0.2)
    output_image = tf.nn.sigmoid(
        tf.keras.layers.Conv2D(3, 4, padding="same")(x)
    )
    return tf.keras.Model([latent, mask], output_image, name="generator")
def build_discriminator(image_shape, downsample_factor=64):
    """Build the discriminator operating on a pair of images.

    Both inputs are concatenated and passed through seven ``downsample``
    blocks (kernel size 4, the first one without normalization) followed by
    a single-filter ``Conv2D``. Every intermediate feature map is exposed
    as an output.

    Args:
        image_shape: Shape of each of the two inputs, without the batch
            axis.
        downsample_factor: Base number of filters; each block uses a
            multiple of it.

    Returns:
        A ``tf.keras.Model`` mapping ``[image_A, image_B]`` to a list of
        eight tensors: the seven block outputs, then the final prediction.
    """
    input_image_A = tf.keras.Input(
        shape=image_shape, name="discriminator_image_A"
    )
    input_image_B = tf.keras.Input(
        shape=image_shape, name="discriminator_image_B"
    )
    merged = tf.keras.layers.Concatenate()([input_image_A, input_image_B])

    # Collect each stage's output; every stage consumes the previous one.
    features = [downsample(downsample_factor, 4, apply_norm=False)(merged)]
    for multiplier in (2, 4, 8, 8, 8, 16):
        stage = downsample(multiplier * downsample_factor, 4)
        features.append(stage(features[-1]))
    features.append(tf.keras.layers.Conv2D(1, 4)(features[-1]))

    return tf.keras.Model([input_image_A, input_image_B], features)
def generator_loss(y):
    """Return the negated mean of ``y``.

    Args:
        y: Discriminator prediction for generated images.

    Returns:
        ``-mean(y)``: the loss shrinks as the prediction grows.
    """
    mean_score = tf.reduce_mean(y)
    return -mean_score
def kl_divergence_loss(mean, variance):
    """Return ``-0.5 * sum(1 + variance - mean**2 - exp(variance))``.

    Args:
        mean: Predicted means.
        variance: Predicted log-variances (the value is exponentiated).

    Returns:
        The sum over all elements, as a scalar tensor.
    """
    per_element = 1 + variance - tf.square(mean) - tf.exp(variance)
    return -0.5 * tf.reduce_sum(per_element)
class FeatureMatchingLoss(tf.keras.losses.Loss):
    """Sum of mean absolute errors between paired feature maps.

    Args:
        **kwargs: Forwarded to ``tf.keras.losses.Loss``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.mae = tf.keras.losses.MeanAbsoluteError()

    def call(self, y_true, y_pred):
        """Accumulate the MAE over every feature pair except the last.

        Args:
            y_true: Sequence of feature maps from the reference pass.
            y_pred: Sequence of feature maps from the pass being compared.

        Returns:
            The summed MAE; the final entry of ``y_true`` is skipped.
        """
        total = 0
        for idx, reference in enumerate(y_true[:-1]):
            total += self.mae(reference, y_pred[idx])
        return total
class VGGFeatureMatchingLoss(tf.keras.losses.Loss):
    """Weighted MAE between VGG19 activations of two image batches.

    Activations are taken from the first convolution of each of the five
    VGG19 blocks (ImageNet weights) and weighted 1/32, 1/16, 1/8, 1/4, 1.

    Args:
        **kwargs: Forwarded to ``tf.keras.losses.Loss``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.encoder_layers = [
            "block1_conv1",
            "block2_conv1",
            "block3_conv1",
            "block4_conv1",
            "block5_conv1",
        ]
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
        vgg = tf.keras.applications.VGG19(include_top=False, weights="imagenet")
        layer_outputs = [
            vgg.get_layer(layer_name).output
            for layer_name in self.encoder_layers
        ]
        self.vgg_model = tf.keras.Model(vgg.input, layer_outputs, name="VGG")
        self.mae = tf.keras.losses.MeanAbsoluteError()

    def call(self, y_true, y_pred):
        """Compare two image batches in VGG19 feature space.

        Args:
            y_true: Reference images.
            y_pred: Images to compare against the reference.

        Returns:
            The weighted sum of per-layer mean absolute errors.
        """
        preprocess = tf.keras.applications.vgg19.preprocess_input
        # ``127.5 * (x + 1)`` maps [-1, 1] onto [0, 255].
        # NOTE(review): build_generator in this file ends with a sigmoid
        # ([0, 1] outputs) - confirm the intended input range.
        real_input = preprocess(127.5 * (y_true + 1))
        fake_input = preprocess(127.5 * (y_pred + 1))
        real_features = self.vgg_model(real_input)
        fake_features = self.vgg_model(fake_input)

        total = 0
        for weight, real, fake in zip(
            self.weights, real_features, fake_features
        ):
            total += weight * self.mae(real, fake)
        return total
class DiscriminatorLoss(tf.keras.losses.Loss):
    """Hinge loss for discriminator predictions.

    Args:
        **kwargs: Forwarded to ``tf.keras.losses.Loss``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hinge_loss = tf.keras.losses.Hinge()

    def call(self, y, is_real):
        """Score predictions ``y`` against a constant target.

        Args:
            y: Discriminator predictions.
            is_real: Truthy when ``y`` was computed on real images.

        Returns:
            Hinge loss of ``y`` against ``1.0`` (real) or ``-1.0`` (fake).
        """
        if is_real:
            label = 1.0
        else:
            label = -1.0
        return self.hinge_loss(label, y)
class GauGAN(tf.keras.Model):
    """GauGAN: encoder, SPADE generator and discriminator trained jointly.

    Each batch handed to ``train_step`` / ``test_step`` is a 3-tuple
    ``(segmentation_map, image, labels)``. ``labels`` is the generator's
    mask input (shape ``(image_size, image_size, num_classes)``);
    ``segmentation_map`` and ``image`` are fed to the discriminator (shape
    ``(image_size, image_size, 3)``).

    Args:
        image_size: Height and width of the images and masks.
        num_classes: Number of channels of the generator mask.
        batch_size: Batch size used by the latent sampler.
        latent_dim: Length of the latent vector.
        feature_loss_coeff: Weight of the discriminator feature-matching
            loss.
        vgg_feature_loss_coeff: Weight of the VGG feature-matching loss.
        kl_divergence_loss_coeff: Weight of the KL-divergence loss.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(
        self,
        image_size,
        num_classes,
        batch_size,
        latent_dim,
        feature_loss_coeff=10,
        vgg_feature_loss_coeff=0.1,
        kl_divergence_loss_coeff=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.image_shape = (image_size, image_size, 3)
        self.mask_shape = (image_size, image_size, num_classes)
        self.feature_loss_coeff = feature_loss_coeff
        self.vgg_feature_loss_coeff = vgg_feature_loss_coeff
        self.kl_divergence_loss_coeff = kl_divergence_loss_coeff

        self.discriminator = build_discriminator(self.image_shape)
        self.generator = build_generator(self.mask_shape, latent_dim=latent_dim)
        self.encoder = build_encoder(self.image_shape, latent_dim=latent_dim)
        self.sampler = GaussianSampler(batch_size, latent_dim)
        self.patch_size, self.combined_model = self.build_combined_generator()

        self.disc_loss_tracker = tf.keras.metrics.Mean(name="disc_loss")
        self.gen_loss_tracker = tf.keras.metrics.Mean(name="gen_loss")
        self.feat_loss_tracker = tf.keras.metrics.Mean(name="feat_loss")
        self.vgg_loss_tracker = tf.keras.metrics.Mean(name="vgg_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Return the loss trackers reported by the train and test steps.

        This must be a property: ``train_step`` and ``test_step`` iterate
        over ``self.metrics``, and ``tf.keras.Model.metrics`` is a property
        that this definition overrides.
        """
        return [
            self.disc_loss_tracker,
            self.gen_loss_tracker,
            self.feat_loss_tracker,
            self.vgg_loss_tracker,
            self.kl_loss_tracker,
        ]

    def build_combined_generator(self):
        """Chain the generator and the (frozen) discriminator into one model.

        The combined model takes ``[latent, mask, image]``, generates an
        image from the latent vector and mask, and passes ``image`` together
        with the generated image to the discriminator.

        Returns:
            Tuple ``(patch_size, combined_model)``. ``patch_size`` is axis 1
            of the discriminator's final output; ``combined_model`` outputs
            ``[discriminator_outputs, generated_image]``.
        """
        # Freeze the discriminator so that only generator variables are
        # trainable in the combined model.
        self.discriminator.trainable = False

        mask_input = tf.keras.Input(shape=self.mask_shape, name="mask")
        image_input = tf.keras.Input(shape=self.image_shape, name="image")
        # ``(self.latent_dim,)`` is a real one-element tuple.
        latent_input = tf.keras.Input(shape=(self.latent_dim,), name="latent")

        generated_image = self.generator([latent_input, mask_input])
        discriminator_output = self.discriminator([image_input, generated_image])
        patch_size = discriminator_output[-1].shape[1]

        combined_model = tf.keras.Model(
            [latent_input, mask_input, image_input],
            [discriminator_output, generated_image],
        )
        return patch_size, combined_model

    def compile(self, gen_lr=1e-4, disc_lr=4e-4, **kwargs):
        """Configure the optimizers and loss objects.

        Args:
            gen_lr: Learning rate of the generator's Adam optimizer.
            disc_lr: Learning rate of the discriminator's Adam optimizer.
            **kwargs: Forwarded to ``tf.keras.Model.compile``.
        """
        super().compile(**kwargs)
        self.generator_optimizer = tf.keras.optimizers.Adam(
            gen_lr, beta_1=0.0, beta_2=0.999
        )
        self.discriminator_optimizer = tf.keras.optimizers.Adam(
            disc_lr, beta_1=0.0, beta_2=0.999
        )
        self.discriminator_loss = DiscriminatorLoss()
        self.feature_matching_loss = FeatureMatchingLoss()
        self.vgg_loss = VGGFeatureMatchingLoss()

    def train_discriminator(self, latent_vector, segmentation_map, real_image, labels):
        """Run one discriminator update.

        Args:
            latent_vector: Latent samples for the generator.
            segmentation_map: First discriminator input.
            real_image: Real images paired with ``segmentation_map``.
            labels: Mask input of the generator.

        Returns:
            The mean of the fake and real hinge losses.
        """
        # Generated outside the tape: no gradient reaches the generator.
        fake_images = self.generator([latent_vector, labels])

        with tf.GradientTape() as gradient_tape:
            pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
            pred_real = self.discriminator([segmentation_map, real_image])[-1]
            loss_fake = self.discriminator_loss(pred_fake, False)
            loss_real = self.discriminator_loss(pred_real, True)
            total_loss = 0.5 * (loss_fake + loss_real)

        # Unfreeze so that ``trainable_variables`` lists the discriminator's
        # weights for the gradient computation below.
        self.discriminator.trainable = True
        gradients = gradient_tape.gradient(
            total_loss, self.discriminator.trainable_variables
        )
        self.discriminator_optimizer.apply_gradients(
            zip(gradients, self.discriminator.trainable_variables)
        )
        return total_loss

    def train_generator(
        self, latent_vector, segmentation_map, labels, image, mean, variance
    ):
        """Run one generator update.

        Args:
            latent_vector: Latent samples for the generator.
            segmentation_map: First discriminator input.
            labels: Mask input of the generator.
            image: Real images.
            mean: Encoder means, used by the KL term.
            variance: Encoder log-variances, used by the KL term.

        Returns:
            Tuple ``(total_loss, feature_loss, vgg_loss, kl_loss)``; the
            last three are already multiplied by their coefficients.
        """
        # The generator learns through the signal provided by the
        # discriminator; only generator parameters are updated here.
        self.discriminator.trainable = False

        with tf.GradientTape() as tape:
            real_d_output = self.discriminator([segmentation_map, image])
            fake_d_output, fake_image = self.combined_model(
                [latent_vector, labels, segmentation_map]
            )
            pred = fake_d_output[-1]

            # Compute generator losses.
            g_loss = generator_loss(pred)
            kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(
                mean, variance
            )
            vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(
                image, fake_image
            )
            feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
                real_d_output, fake_d_output
            )
            total_loss = g_loss + kl_loss + vgg_loss + feature_loss

        gradients = tape.gradient(
            total_loss, self.combined_model.trainable_variables
        )
        self.generator_optimizer.apply_gradients(
            zip(gradients, self.combined_model.trainable_variables)
        )
        return total_loss, feature_loss, vgg_loss, kl_loss

    def train_step(self, data):
        """Run one discriminator update and one generator update.

        Args:
            data: Batch ``(segmentation_map, image, labels)``.

        Returns:
            Dict mapping each tracker's name to its current result.
        """
        segmentation_map, image, labels = data

        # NOTE(review): the encoder runs outside both gradient tapes and its
        # variables are not part of ``combined_model``, so neither update
        # below changes the encoder's weights.
        mean, variance = self.encoder(image)
        latent_vector = self.sampler([mean, variance])

        disc_loss = self.train_discriminator(
            latent_vector, segmentation_map, image, labels
        )
        gen_loss, feature_loss, vgg_loss, kl_loss = self.train_generator(
            latent_vector, segmentation_map, labels, image, mean, variance
        )

        # Report progress.
        self.disc_loss_tracker.update_state(disc_loss)
        self.gen_loss_tracker.update_state(gen_loss)
        self.feat_loss_tracker.update_state(feature_loss)
        self.vgg_loss_tracker.update_state(vgg_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate all losses on one batch without updating any weights.

        Args:
            data: Batch ``(segmentation_map, image, labels)``.

        Returns:
            Dict mapping each tracker's name to its current result.
        """
        segmentation_map, image, labels = data

        # Obtain the learned moments of the real image distribution.
        mean, variance = self.encoder(image)
        # Sample a latent from the distribution defined by those moments.
        latent_vector = self.sampler([mean, variance])
        # Generate the fake images.
        fake_images = self.generator([latent_vector, labels])

        # Discriminator losses. The real pass is computed once and reused
        # for the feature-matching loss below.
        real_d_output = self.discriminator([segmentation_map, image])
        pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
        pred_real = real_d_output[-1]
        loss_fake = self.discriminator_loss(pred_fake, False)
        loss_real = self.discriminator_loss(pred_real, True)
        total_discriminator_loss = 0.5 * (loss_fake + loss_real)

        # Generator losses.
        fake_d_output, fake_image = self.combined_model(
            [latent_vector, labels, segmentation_map]
        )
        pred = fake_d_output[-1]
        g_loss = generator_loss(pred)
        kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(
            mean, variance
        )
        vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
        feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
            real_d_output, fake_d_output
        )
        total_generator_loss = g_loss + kl_loss + vgg_loss + feature_loss

        # Report progress.
        self.disc_loss_tracker.update_state(total_discriminator_loss)
        self.gen_loss_tracker.update_state(total_generator_loss)
        self.feat_loss_tracker.update_state(feature_loss)
        self.vgg_loss_tracker.update_state(vgg_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {m.name: m.result() for m in self.metrics}

    def call(self, inputs):
        """Generate images from ``(latent_vectors, labels)``.

        Args:
            inputs: Pair ``(latent_vectors, labels)``.

        Returns:
            The generator's output for those inputs.
        """
        latent_vectors, labels = inputs
        return self.generator([latent_vectors, labels])