# This file is based on the GauGAN implementation by Rakshit et al.
# https://keras.io/examples/generative/gaugan/

import tensorflow as tf
import tensorflow_addons as tfa


class SPADE(tf.keras.layers.Layer):
    def __init__(self, filters, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        self.conv = tf.keras.layers.Conv2D(128, 3, padding="same", activation="relu")
        self.conv_gamma = tf.keras.layers.Conv2D(filters, 3, padding="same")
        self.conv_beta = tf.keras.layers.Conv2D(filters, 3, padding="same")

    def build(self, input_shape):
        self.resize_shape = input_shape[1:3]

    def call(self, input_tensor, raw_mask):
        # Resize the mask to the spatial resolution of the incoming feature map,
        # then predict spatially-varying gamma and beta from it.
        mask = tf.image.resize(raw_mask, self.resize_shape, method="nearest")
        x = self.conv(mask)
        gamma = self.conv_gamma(x)
        beta = self.conv_beta(x)
        mean, var = tf.nn.moments(input_tensor, axes=(0, 1, 2), keepdims=True)
        std = tf.sqrt(var + self.epsilon)
        normalized = (input_tensor - mean) / std
        output = gamma * normalized + beta
        return output


class ResBlock(tf.keras.layers.Layer):
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

    def build(self, input_shape):
        input_filter = input_shape[-1]
        self.spade_1 = SPADE(input_filter)
        self.spade_2 = SPADE(self.filters)
        self.conv_1 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")
        self.conv_2 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")
        self.learned_skip = False
        if self.filters != input_filter:
            # Learn a projection on the skip path when the channel count changes.
            self.learned_skip = True
            self.spade_3 = SPADE(input_filter)
            self.conv_3 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")

    def call(self, input_tensor, mask):
        x = self.spade_1(input_tensor, mask)
        x = self.conv_1(tf.nn.leaky_relu(x, 0.2))
        x = self.spade_2(x, mask)
        x = self.conv_2(tf.nn.leaky_relu(x, 0.2))
        skip = (
            self.conv_3(tf.nn.leaky_relu(self.spade_3(input_tensor, mask), 0.2))
            if self.learned_skip
            else input_tensor
        )
        output = skip + x
        return output


class GaussianSampler(tf.keras.layers.Layer):
    def __init__(self, batch_size, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.latent_dim = latent_dim

    def call(self, inputs):
        # Reparameterization trick: z = mean + exp(0.5 * log_variance) * eps.
        means, variance = inputs
        epsilon = tf.random.normal(
            shape=(self.batch_size, self.latent_dim), mean=0.0, stddev=1.0
        )
        samples = means + tf.exp(0.5 * variance) * epsilon
        return samples


def downsample(
    channels,
    kernels,
    strides=2,
    apply_norm=True,
    apply_activation=True,
    apply_dropout=False,
):
    block = tf.keras.Sequential()
    block.add(
        tf.keras.layers.Conv2D(
            channels,
            kernels,
            strides=strides,
            padding="same",
            use_bias=False,
            kernel_initializer=tf.keras.initializers.GlorotNormal(),
        )
    )
    if apply_norm:
        block.add(tfa.layers.InstanceNormalization())
    if apply_activation:
        block.add(tf.keras.layers.LeakyReLU(0.2))
    if apply_dropout:
        block.add(tf.keras.layers.Dropout(0.5))
    return block


def build_encoder(image_shape, encoder_downsample_factor=64, latent_dim=256):
    input_image = tf.keras.Input(shape=image_shape)
    x = downsample(encoder_downsample_factor, 3, apply_norm=False)(input_image)
    x = downsample(2 * encoder_downsample_factor, 3)(x)
    x = downsample(4 * encoder_downsample_factor, 3)(x)
    x = downsample(8 * encoder_downsample_factor, 3)(x)
    x = downsample(8 * encoder_downsample_factor, 3)(x)
    x = downsample(8 * encoder_downsample_factor, 3)(x)
    x = downsample(16 * encoder_downsample_factor, 3)(x)
    x = tf.keras.layers.Flatten()(x)
    mean = tf.keras.layers.Dense(latent_dim, name="mean")(x)
    variance = tf.keras.layers.Dense(latent_dim, name="variance")(x)
    return tf.keras.Model(input_image, [mean, variance], name="encoder")
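# The generator below starts from a 4x4 feature map and applies eight
# UpSampling2D stages, so its output resolution is fixed at 4 * 2**8 = 1024;
# masks and images fed to this model are therefore expected to be 1024x1024.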
def build_generator(mask_shape, latent_dim=256):
    latent = tf.keras.Input(shape=(latent_dim,))
    mask = tf.keras.Input(shape=mask_shape)
    # Project the latent vector to a 4x4x1024 feature map (4 * 4 * 1024 = 16384).
    x = tf.keras.layers.Dense(16384)(latent)
    x = tf.keras.layers.Reshape((4, 4, 1024))(x)
    x = ResBlock(filters=1024)(x, mask)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=1024)(x, mask)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=1024)(x, mask)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=512)(x, mask)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=256)(x, mask)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=128)(x, mask)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=64)(x, mask)  # These two layers were added
    x = tf.keras.layers.UpSampling2D((2, 2))(x)  # to support 512x512 images.
    x = ResBlock(filters=32)(x, mask)  # These two layers were added
    x = tf.keras.layers.UpSampling2D((2, 2))(x)  # to support 1024x1024 images.
    x = tf.nn.leaky_relu(x, 0.2)
    output_image = tf.nn.sigmoid(tf.keras.layers.Conv2D(3, 4, padding="same")(x))
    return tf.keras.Model([latent, mask], output_image, name="generator")


def build_discriminator(image_shape, downsample_factor=64):
    input_image_A = tf.keras.Input(shape=image_shape, name="discriminator_image_A")
    input_image_B = tf.keras.Input(shape=image_shape, name="discriminator_image_B")
    x = tf.keras.layers.Concatenate()([input_image_A, input_image_B])
    x1 = downsample(downsample_factor, 4, apply_norm=False)(x)
    x2 = downsample(2 * downsample_factor, 4)(x1)
    x3 = downsample(4 * downsample_factor, 4)(x2)
    x4 = downsample(8 * downsample_factor, 4)(x3)
    x5 = downsample(8 * downsample_factor, 4)(x4)
    x6 = downsample(8 * downsample_factor, 4)(x5)
    x7 = downsample(16 * downsample_factor, 4)(x6)
    x8 = tf.keras.layers.Conv2D(1, 4)(x7)
    # Expose the intermediate feature maps for the feature matching loss; the
    # last output is the patch-wise real/fake prediction.
    outputs = [x1, x2, x3, x4, x5, x6, x7, x8]
    return tf.keras.Model([input_image_A, input_image_B], outputs)


def generator_loss(y):
    return -tf.reduce_mean(y)


def kl_divergence_loss(mean, variance):
    return -0.5 * tf.reduce_sum(1 + variance - tf.square(mean) - tf.exp(variance))


class FeatureMatchingLoss(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.mae = tf.keras.losses.MeanAbsoluteError()

    def call(self, y_true, y_pred):
        loss = 0
        for i in range(len(y_true) - 1):
            loss += self.mae(y_true[i], y_pred[i])
        return loss


class VGGFeatureMatchingLoss(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.encoder_layers = [
            "block1_conv1",
            "block2_conv1",
            "block3_conv1",
            "block4_conv1",
            "block5_conv1",
        ]
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
        vgg = tf.keras.applications.VGG19(include_top=False, weights="imagenet")
        layer_outputs = [vgg.get_layer(x).output for x in self.encoder_layers]
        self.vgg_model = tf.keras.Model(vgg.input, layer_outputs, name="VGG")
        self.mae = tf.keras.losses.MeanAbsoluteError()

    def call(self, y_true, y_pred):
        y_true = tf.keras.applications.vgg19.preprocess_input(127.5 * (y_true + 1))
        y_pred = tf.keras.applications.vgg19.preprocess_input(127.5 * (y_pred + 1))
        real_features = self.vgg_model(y_true)
        fake_features = self.vgg_model(y_pred)
        loss = 0
        for i in range(len(real_features)):
            loss += self.weights[i] * self.mae(real_features[i], fake_features[i])
        return loss


class DiscriminatorLoss(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hinge_loss = tf.keras.losses.Hinge()

    def call(self, y, is_real):
        label = 1.0 if is_real else -1.0
        return self.hinge_loss(label, y)
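# The GauGAN model below ties the pieces above together. The discriminator is
# trained with an averaged hinge loss on real and generated pairs, while the
# generator minimizes an adversarial term plus three weighted auxiliary terms
# (default coefficients shown, see __init__):
#
#   total_generator_loss = -mean(D(fake))
#                        + 10.0 * feature_matching(D_features(real), D_features(fake))
#                        + 0.1  * vgg_perceptual(real, fake)
#                        + 0.1  * kl_divergence(mean, log_variance)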
class GauGAN(tf.keras.Model):
    def __init__(
        self,
        image_size,
        num_classes,
        batch_size,
        latent_dim,
        feature_loss_coeff=10,
        vgg_feature_loss_coeff=0.1,
        kl_divergence_loss_coeff=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.image_shape = (image_size, image_size, 3)
        self.mask_shape = (image_size, image_size, num_classes)
        self.feature_loss_coeff = feature_loss_coeff
        self.vgg_feature_loss_coeff = vgg_feature_loss_coeff
        self.kl_divergence_loss_coeff = kl_divergence_loss_coeff

        self.discriminator = build_discriminator(self.image_shape)
        self.generator = build_generator(self.mask_shape, latent_dim=latent_dim)
        self.encoder = build_encoder(self.image_shape, latent_dim=latent_dim)
        self.sampler = GaussianSampler(batch_size, latent_dim)
        self.patch_size, self.combined_model = self.build_combined_generator()

        self.disc_loss_tracker = tf.keras.metrics.Mean(name="disc_loss")
        self.gen_loss_tracker = tf.keras.metrics.Mean(name="gen_loss")
        self.feat_loss_tracker = tf.keras.metrics.Mean(name="feat_loss")
        self.vgg_loss_tracker = tf.keras.metrics.Mean(name="vgg_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.disc_loss_tracker,
            self.gen_loss_tracker,
            self.feat_loss_tracker,
            self.vgg_loss_tracker,
            self.kl_loss_tracker,
        ]

    def build_combined_generator(self):
        # This method builds a model that takes as inputs the following:
        # latent vector, one-hot encoded segmentation label map, and
        # a segmentation map. It then (i) generates an image with the generator,
        # (ii) passes the generated images and segmentation map to the discriminator.
        # Finally, the model produces the following outputs: (a) discriminator outputs,
        # (b) generated image.
        # We will be using this model to simplify the implementation.
        self.discriminator.trainable = False
        mask_input = tf.keras.Input(shape=self.mask_shape, name="mask")
        image_input = tf.keras.Input(shape=self.image_shape, name="image")
        latent_input = tf.keras.Input(shape=(self.latent_dim,), name="latent")
        generated_image = self.generator([latent_input, mask_input])
        discriminator_output = self.discriminator([image_input, generated_image])
        patch_size = discriminator_output[-1].shape[1]
        combined_model = tf.keras.Model(
            [latent_input, mask_input, image_input],
            [discriminator_output, generated_image],
        )
        return patch_size, combined_model

    def compile(self, gen_lr=1e-4, disc_lr=4e-4, **kwargs):
        super().compile(**kwargs)
        self.generator_optimizer = tf.keras.optimizers.Adam(
            gen_lr, beta_1=0.0, beta_2=0.999
        )
        self.discriminator_optimizer = tf.keras.optimizers.Adam(
            disc_lr, beta_1=0.0, beta_2=0.999
        )
        self.discriminator_loss = DiscriminatorLoss()
        self.feature_matching_loss = FeatureMatchingLoss()
        self.vgg_loss = VGGFeatureMatchingLoss()

    def train_discriminator(self, latent_vector, segmentation_map, real_image, labels):
        fake_images = self.generator([latent_vector, labels])
        with tf.GradientTape() as gradient_tape:
            pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
            pred_real = self.discriminator([segmentation_map, real_image])[-1]
            loss_fake = self.discriminator_loss(pred_fake, False)
            loss_real = self.discriminator_loss(pred_real, True)
            total_loss = 0.5 * (loss_fake + loss_real)

        self.discriminator.trainable = True
        gradients = gradient_tape.gradient(
            total_loss, self.discriminator.trainable_variables
        )
        self.discriminator_optimizer.apply_gradients(
            zip(gradients, self.discriminator.trainable_variables)
        )
        return total_loss

    def train_generator(
        self, latent_vector, segmentation_map, labels, image, mean, variance
    ):
        # The generator learns through the signal provided by the discriminator.
        # During backpropagation, we only update the generator parameters.
        self.discriminator.trainable = False
        with tf.GradientTape() as tape:
            real_d_output = self.discriminator([segmentation_map, image])
            fake_d_output, fake_image = self.combined_model(
                [latent_vector, labels, segmentation_map]
            )
            pred = fake_d_output[-1]

            # Compute the generator losses.
            g_loss = generator_loss(pred)
            kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
            vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
            feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
                real_d_output, fake_d_output
            )
            total_loss = g_loss + kl_loss + vgg_loss + feature_loss

        gradients = tape.gradient(total_loss, self.combined_model.trainable_variables)
        self.generator_optimizer.apply_gradients(
            zip(gradients, self.combined_model.trainable_variables)
        )
        return total_loss, feature_loss, vgg_loss, kl_loss

    def train_step(self, data):
        segmentation_map, image, labels = data
        mean, variance = self.encoder(image)
        latent_vector = self.sampler([mean, variance])
        discriminator_loss = self.train_discriminator(
            latent_vector, segmentation_map, image, labels
        )
        (generator_loss, feature_loss, vgg_loss, kl_loss) = self.train_generator(
            latent_vector, segmentation_map, labels, image, mean, variance
        )

        # Report progress.
        self.disc_loss_tracker.update_state(discriminator_loss)
        self.gen_loss_tracker.update_state(generator_loss)
        self.feat_loss_tracker.update_state(feature_loss)
        self.vgg_loss_tracker.update_state(vgg_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        results = {m.name: m.result() for m in self.metrics}
        return results

    def test_step(self, data):
        segmentation_map, image, labels = data
        # Obtain the learned moments of the real image distribution.
        mean, variance = self.encoder(image)
        # Sample a latent from the distribution defined by the learned moments.
        latent_vector = self.sampler([mean, variance])
        # Generate the fake images.
        fake_images = self.generator([latent_vector, labels])

        # Calculate the losses.
        pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
        pred_real = self.discriminator([segmentation_map, image])[-1]
        loss_fake = self.discriminator_loss(pred_fake, False)
        loss_real = self.discriminator_loss(pred_real, True)
        total_discriminator_loss = 0.5 * (loss_fake + loss_real)

        real_d_output = self.discriminator([segmentation_map, image])
        fake_d_output, fake_image = self.combined_model(
            [latent_vector, labels, segmentation_map]
        )
        pred = fake_d_output[-1]
        g_loss = generator_loss(pred)
        kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
        vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
        feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
            real_d_output, fake_d_output
        )
        total_generator_loss = g_loss + kl_loss + vgg_loss + feature_loss

        # Report progress.
        self.disc_loss_tracker.update_state(total_discriminator_loss)
        self.gen_loss_tracker.update_state(total_generator_loss)
        self.feat_loss_tracker.update_state(feature_loss)
        self.vgg_loss_tracker.update_state(vgg_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        results = {m.name: m.result() for m in self.metrics}
        return results

    def call(self, inputs):
        latent_vectors, labels = inputs
        return self.generator([latent_vectors, labels])
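
# A minimal usage sketch, not part of the original example: the number of
# segmentation classes, the batch size, and the random inputs below are
# illustrative assumptions. The image size must be 1024 because the generator's
# output resolution is fixed by its architecture, and `compile()` downloads the
# ImageNet weights for VGG19 the first time it runs.
if __name__ == "__main__":
    BATCH_SIZE = 1
    NUM_CLASSES = 12
    IMAGE_SIZE = 1024

    gaugan = GauGAN(
        image_size=IMAGE_SIZE,
        num_classes=NUM_CLASSES,
        batch_size=BATCH_SIZE,
        latent_dim=256,
    )
    gaugan.compile()

    # Training expects batches of (segmentation_map, image, one_hot_labels):
    #   segmentation_map: (batch, 1024, 1024, 3) color-coded segmentation image
    #   image:            (batch, 1024, 1024, 3) real photo
    #   one_hot_labels:   (batch, 1024, 1024, NUM_CLASSES) one-hot label map
    # e.g. gaugan.fit(train_dataset, validation_data=val_dataset, epochs=...)
    # (train_dataset / val_dataset are hypothetical tf.data pipelines).

    # Inference: sample a latent vector and condition on a one-hot label map.
    latent = tf.random.normal((BATCH_SIZE, 256))
    labels = tf.one_hot(
        tf.random.uniform(
            (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE), maxval=NUM_CLASSES, dtype=tf.int32
        ),
        NUM_CLASSES,
    )
    generated = gaugan([latent, labels])
    print(generated.shape)  # (BATCH_SIZE, 1024, 1024, 3)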