# RetinaGAN/models/gaugan.py
# This file is based on the GauGAN Keras example by Rakshit et al.
# https://keras.io/examples/generative/gaugan/
import tensorflow as tf
import tensorflow_addons as tfa
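
# SPADE (SPatially-Adaptive DEnormalization): normalizes the incoming feature
# map and re-modulates it with per-pixel gamma/beta maps predicted from the
# segmentation mask, resized to the feature map's spatial resolution.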
class SPADE(tf.keras.layers.Layer):
def __init__(self, filters, epsilon=1e-5, **kwargs):
super().__init__(**kwargs)
self.epsilon = epsilon
self.conv = tf.keras.layers.Conv2D(128, 3, padding="same", activation="relu")
self.conv_gamma = tf.keras.layers.Conv2D(filters, 3, padding="same")
self.conv_beta = tf.keras.layers.Conv2D(filters, 3, padding="same")
def build(self, input_shape):
self.resize_shape = input_shape[1:3]
def call(self, input_tensor, raw_mask):
mask = tf.image.resize(raw_mask, self.resize_shape, method="nearest")
x = self.conv(mask)
gamma = self.conv_gamma(x)
beta = self.conv_beta(x)
mean, var = tf.nn.moments(input_tensor, axes=(0, 1, 2), keepdims=True)
std = tf.sqrt(var + self.epsilon)
normalized = (input_tensor - mean) / std
output = gamma * normalized + beta
return output
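
# Residual block whose normalization layers are SPADE, so the segmentation
# mask conditions every stage of the generator. When the number of filters
# changes, the skip connection is learned (SPADE + conv) instead of identity.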
class ResBlock(tf.keras.layers.Layer):
def __init__(self, filters, **kwargs):
super().__init__(**kwargs)
self.filters = filters
def build(self, input_shape):
input_filter = input_shape[-1]
self.spade_1 = SPADE(input_filter)
self.spade_2 = SPADE(self.filters)
self.conv_1 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")
self.conv_2 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")
self.learned_skip = False
if self.filters != input_filter:
self.learned_skip = True
self.spade_3 = SPADE(input_filter)
self.conv_3 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")
def call(self, input_tensor, mask):
x = self.spade_1(input_tensor, mask)
x = self.conv_1(tf.nn.leaky_relu(x, 0.2))
x = self.spade_2(x, mask)
x = self.conv_2(tf.nn.leaky_relu(x, 0.2))
skip = (
self.conv_3(tf.nn.leaky_relu(self.spade_3(input_tensor, mask), 0.2))
if self.learned_skip
else input_tensor
)
output = skip + x
return output
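
# Reparameterization trick: z = mean + exp(0.5 * variance) * eps, with eps
# drawn from a standard normal. Note that `variance` is used as a
# log-variance here (consistent with kl_divergence_loss below).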
class GaussianSampler(tf.keras.layers.Layer):
def __init__(self, batch_size, latent_dim, **kwargs):
super().__init__(**kwargs)
self.batch_size = batch_size
self.latent_dim = latent_dim
def call(self, inputs):
means, variance = inputs
epsilon = tf.random.normal(
shape=(self.batch_size, self.latent_dim), mean=0.0, stddev=1.0
)
samples = means + tf.exp(0.5 * variance) * epsilon
return samples
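
# Helper producing a Sequential block: strided Conv2D (no bias, Glorot-normal
# init), optionally followed by instance normalization, LeakyReLU(0.2) and
# dropout. Used by both the encoder and the discriminator.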
def downsample(
channels,
kernels,
strides=2,
apply_norm=True,
apply_activation=True,
apply_dropout=False,
):
block = tf.keras.Sequential()
block.add(
tf.keras.layers.Conv2D(
channels,
kernels,
strides=strides,
padding="same",
use_bias=False,
kernel_initializer=tf.keras.initializers.GlorotNormal(),
)
)
if apply_norm:
block.add(tfa.layers.InstanceNormalization())
if apply_activation:
block.add(tf.keras.layers.LeakyReLU(0.2))
if apply_dropout:
block.add(tf.keras.layers.Dropout(0.5))
return block
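
# Image encoder: a stack of downsampling conv blocks followed by two dense
# heads predicting the mean and (log-)variance of the latent distribution.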
def build_encoder(image_shape, encoder_downsample_factor=64, latent_dim=256):
input_image = tf.keras.Input(shape=image_shape)
x = downsample(encoder_downsample_factor, 3, apply_norm=False)(input_image)
x = downsample(2 * encoder_downsample_factor, 3)(x)
x = downsample(4 * encoder_downsample_factor, 3)(x)
x = downsample(8 * encoder_downsample_factor, 3)(x)
x = downsample(8 * encoder_downsample_factor, 3)(x)
x = downsample(8 * encoder_downsample_factor, 3)(x)
x = downsample(16 * encoder_downsample_factor, 3)(x)
x = tf.keras.layers.Flatten()(x)
mean = tf.keras.layers.Dense(latent_dim, name="mean")(x)
variance = tf.keras.layers.Dense(latent_dim, name="variance")(x)
return tf.keras.Model(input_image, [mean, variance], name="encoder")
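
# SPADE generator: projects the latent vector to a 4x4x1024 tensor, then
# alternates SPADE ResBlocks with 2x upsampling. Eight upsampling stages take
# the output from 4x4 to 1024x1024; the final conv maps to 3 channels with a
# sigmoid so pixel values lie in [0, 1].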
def build_generator(mask_shape, latent_dim=256):
    latent = tf.keras.Input(shape=(latent_dim,))
mask = tf.keras.Input(shape=mask_shape)
x = tf.keras.layers.Dense(16384)(latent)
x = tf.keras.layers.Reshape((4, 4, 1024))(x)
x = ResBlock(filters=1024)(x, mask)
x = tf.keras.layers.UpSampling2D((2, 2))(x)
x = ResBlock(filters=1024)(x, mask)
x = tf.keras.layers.UpSampling2D((2, 2))(x)
x = ResBlock(filters=1024)(x, mask)
x = tf.keras.layers.UpSampling2D((2, 2))(x)
x = ResBlock(filters=512)(x, mask)
x = tf.keras.layers.UpSampling2D((2, 2))(x)
x = ResBlock(filters=256)(x, mask)
x = tf.keras.layers.UpSampling2D((2, 2))(x)
x = ResBlock(filters=128)(x, mask)
x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=64)(x, mask)            # extra ResBlock/upsampling pair
    x = tf.keras.layers.UpSampling2D((2, 2))(x)  # added to reach a 512x512 output
    x = ResBlock(filters=32)(x, mask)            # extra ResBlock/upsampling pair
    x = tf.keras.layers.UpSampling2D((2, 2))(x)  # added to reach a 1024x1024 output
x = tf.nn.leaky_relu(x, 0.2)
output_image = tf.nn.sigmoid(tf.keras.layers.Conv2D(3, 4, padding="same")(x))
return tf.keras.Model([latent, mask], output_image, name="generator")
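
# PatchGAN-style discriminator over the concatenated (segmentation map, image)
# pair. It returns every intermediate activation in addition to the final
# patch logits, so that the feature-matching loss can reuse them.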
def build_discriminator(image_shape, downsample_factor=64):
input_image_A = tf.keras.Input(shape=image_shape, name="discriminator_image_A")
input_image_B = tf.keras.Input(shape=image_shape, name="discriminator_image_B")
x = tf.keras.layers.Concatenate()([input_image_A, input_image_B])
x1 = downsample(downsample_factor, 4, apply_norm=False)(x)
x2 = downsample(2 * downsample_factor, 4)(x1)
x3 = downsample(4 * downsample_factor, 4)(x2)
x4 = downsample(8 * downsample_factor, 4)(x3)
x5 = downsample(8 * downsample_factor, 4)(x4)
x6 = downsample(8 * downsample_factor, 4)(x5)
x7 = downsample(16 * downsample_factor, 4)(x6)
x8 = tf.keras.layers.Conv2D(1, 4)(x7)
outputs = [x1, x2, x3, x4, x5, x6, x7, x8]
return tf.keras.Model([input_image_A, input_image_B], outputs)
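
# Hinge-style adversarial loss for the generator (maximize the discriminator
# score on fakes) and the closed-form KL term of the VAE encoder, where
# `variance` is the predicted log-variance.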
def generator_loss(y):
return -tf.reduce_mean(y)
def kl_divergence_loss(mean, variance):
return -0.5 * tf.reduce_sum(1 + variance - tf.square(mean) - tf.exp(variance))
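
# L1 distance between the discriminator's intermediate feature maps for real
# and generated images; the last element of each list (the patch logits) is
# skipped.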
class FeatureMatchingLoss(tf.keras.losses.Loss):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.mae = tf.keras.losses.MeanAbsoluteError()
def call(self, y_true, y_pred):
loss = 0
for i in range(len(y_true) - 1):
loss += self.mae(y_true[i], y_pred[i])
return loss
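
# Perceptual loss: weighted L1 distance between VGG19 activations of the real
# and generated images. Inputs are rescaled as 127.5 * (x + 1) before the
# standard VGG19 preprocessing.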
class VGGFeatureMatchingLoss(tf.keras.losses.Loss):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.encoder_layers = [
"block1_conv1",
"block2_conv1",
"block3_conv1",
"block4_conv1",
"block5_conv1",
]
self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
vgg = tf.keras.applications.VGG19(include_top=False, weights="imagenet")
layer_outputs = [vgg.get_layer(x).output for x in self.encoder_layers]
self.vgg_model = tf.keras.Model(vgg.input, layer_outputs, name="VGG")
self.mae = tf.keras.losses.MeanAbsoluteError()
def call(self, y_true, y_pred):
y_true = tf.keras.applications.vgg19.preprocess_input(127.5 * (y_true + 1))
y_pred = tf.keras.applications.vgg19.preprocess_input(127.5 * (y_pred + 1))
real_features = self.vgg_model(y_true)
fake_features = self.vgg_model(y_pred)
loss = 0
for i in range(len(real_features)):
loss += self.weights[i] * self.mae(real_features[i], fake_features[i])
return loss
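
# Hinge loss for the discriminator: real patch logits are pushed above +1 and
# fake patch logits below -1.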
class DiscriminatorLoss(tf.keras.losses.Loss):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.hinge_loss = tf.keras.losses.Hinge()
def call(self, y, is_real):
label = 1.0 if is_real else -1.0
return self.hinge_loss(label, y)
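
# End-to-end GauGAN model: encoder + Gaussian sampler + SPADE generator +
# PatchGAN discriminator, with custom train/test steps that alternate
# discriminator and generator updates and track every loss term as a metric.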
class GauGAN(tf.keras.Model):
def __init__(
self,
image_size,
num_classes,
batch_size,
latent_dim,
feature_loss_coeff=10,
vgg_feature_loss_coeff=0.1,
kl_divergence_loss_coeff=0.1,
**kwargs,
):
super().__init__(**kwargs)
self.image_size = image_size
self.latent_dim = latent_dim
self.batch_size = batch_size
self.num_classes = num_classes
self.image_shape = (image_size, image_size, 3)
self.mask_shape = (image_size, image_size, num_classes)
self.feature_loss_coeff = feature_loss_coeff
self.vgg_feature_loss_coeff = vgg_feature_loss_coeff
self.kl_divergence_loss_coeff = kl_divergence_loss_coeff
self.discriminator = build_discriminator(self.image_shape)
self.generator = build_generator(self.mask_shape, latent_dim=latent_dim)
self.encoder = build_encoder(self.image_shape, latent_dim=latent_dim)
self.sampler = GaussianSampler(batch_size, latent_dim)
self.patch_size, self.combined_model = self.build_combined_generator()
self.disc_loss_tracker = tf.keras.metrics.Mean(name="disc_loss")
self.gen_loss_tracker = tf.keras.metrics.Mean(name="gen_loss")
self.feat_loss_tracker = tf.keras.metrics.Mean(name="feat_loss")
self.vgg_loss_tracker = tf.keras.metrics.Mean(name="vgg_loss")
self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
@property
def metrics(self):
return [
self.disc_loss_tracker,
self.gen_loss_tracker,
self.feat_loss_tracker,
self.vgg_loss_tracker,
self.kl_loss_tracker,
]
def build_combined_generator(self):
# This method builds a model that takes as inputs the following:
# latent vector, one-hot encoded segmentation label map, and
# a segmentation map. It then (i) generates an image with the generator,
# (ii) passes the generated images and segmentation map to the discriminator.
# Finally, the model produces the following outputs: (a) discriminator outputs,
# (b) generated image.
# We will be using this model to simplify the implementation.
self.discriminator.trainable = False
mask_input = tf.keras.Input(shape=self.mask_shape, name="mask")
image_input = tf.keras.Input(shape=self.image_shape, name="image")
        latent_input = tf.keras.Input(shape=(self.latent_dim,), name="latent")
generated_image = self.generator([latent_input, mask_input])
discriminator_output = self.discriminator([image_input, generated_image])
patch_size = discriminator_output[-1].shape[1]
combined_model = tf.keras.Model(
[latent_input, mask_input, image_input],
[discriminator_output, generated_image],
)
return patch_size, combined_model
def compile(self, gen_lr=1e-4, disc_lr=4e-4, **kwargs):
super().compile(**kwargs)
self.generator_optimizer = tf.keras.optimizers.Adam(
gen_lr, beta_1=0.0, beta_2=0.999
)
self.discriminator_optimizer = tf.keras.optimizers.Adam(
disc_lr, beta_1=0.0, beta_2=0.999
)
self.discriminator_loss = DiscriminatorLoss()
self.feature_matching_loss = FeatureMatchingLoss()
self.vgg_loss = VGGFeatureMatchingLoss()
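
    # One discriminator update: score real and generated images against the
    # same segmentation map and average the two hinge losses.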
def train_discriminator(self, latent_vector, segmentation_map, real_image, labels):
fake_images = self.generator([latent_vector, labels])
with tf.GradientTape() as gradient_tape:
pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
pred_real = self.discriminator([segmentation_map, real_image])[-1]
loss_fake = self.discriminator_loss(pred_fake, False)
loss_real = self.discriminator_loss(pred_real, True)
total_loss = 0.5 * (loss_fake + loss_real)
self.discriminator.trainable = True
gradients = gradient_tape.gradient(
total_loss, self.discriminator.trainable_variables
)
self.discriminator_optimizer.apply_gradients(
zip(gradients, self.discriminator.trainable_variables)
)
return total_loss
def train_generator(
self, latent_vector, segmentation_map, labels, image, mean, variance
):
# Generator learns through the signal provided by the discriminator. During
# backpropagation, we only update the generator parameters.
self.discriminator.trainable = False
with tf.GradientTape() as tape:
real_d_output = self.discriminator([segmentation_map, image])
fake_d_output, fake_image = self.combined_model(
[latent_vector, labels, segmentation_map]
)
pred = fake_d_output[-1]
# Compute generator losses.
g_loss = generator_loss(pred)
kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
            feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
                real_d_output, fake_d_output
            )
total_loss = g_loss + kl_loss + vgg_loss + feature_loss
gradients = tape.gradient(total_loss, self.combined_model.trainable_variables)
self.generator_optimizer.apply_gradients(
zip(gradients, self.combined_model.trainable_variables)
)
return total_loss, feature_loss, vgg_loss, kl_loss
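
    # One full optimization step: encode the real image, sample a latent
    # vector, update the discriminator, then update the generator.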
def train_step(self, data):
segmentation_map, image, labels = data
mean, variance = self.encoder(image)
latent_vector = self.sampler([mean, variance])
discriminator_loss = self.train_discriminator(
latent_vector, segmentation_map, image, labels
)
(generator_loss, feature_loss, vgg_loss, kl_loss) = self.train_generator(
latent_vector, segmentation_map, labels, image, mean, variance
)
# Report progress.
self.disc_loss_tracker.update_state(discriminator_loss)
self.gen_loss_tracker.update_state(generator_loss)
self.feat_loss_tracker.update_state(feature_loss)
self.vgg_loss_tracker.update_state(vgg_loss)
self.kl_loss_tracker.update_state(kl_loss)
results = {m.name: m.result() for m in self.metrics}
return results
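
    # Validation step: computes the same losses as train_step but applies no
    # gradient updates.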
def test_step(self, data):
segmentation_map, image, labels = data
# Obtain the learned moments of the real image distribution.
mean, variance = self.encoder(image)
# Sample a latent from the distribution defined by the learned moments.
latent_vector = self.sampler([mean, variance])
# Generate the fake images.
fake_images = self.generator([latent_vector, labels])
# Calculate the losses.
pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
pred_real = self.discriminator([segmentation_map, image])[-1]
loss_fake = self.discriminator_loss(pred_fake, False)
loss_real = self.discriminator_loss(pred_real, True)
total_discriminator_loss = 0.5 * (loss_fake + loss_real)
real_d_output = self.discriminator([segmentation_map, image])
fake_d_output, fake_image = self.combined_model(
[latent_vector, labels, segmentation_map]
)
pred = fake_d_output[-1]
g_loss = generator_loss(pred)
kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
real_d_output, fake_d_output
)
total_generator_loss = g_loss + kl_loss + vgg_loss + feature_loss
# Report progress.
self.disc_loss_tracker.update_state(total_discriminator_loss)
self.gen_loss_tracker.update_state(total_generator_loss)
self.feat_loss_tracker.update_state(feature_loss)
self.vgg_loss_tracker.update_state(vgg_loss)
self.kl_loss_tracker.update_state(kl_loss)
results = {m.name: m.result() for m in self.metrics}
return results
def call(self, inputs):
latent_vectors, labels = inputs
return self.generator([latent_vectors, labels])
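

# Minimal usage sketch (illustrative only, not part of the original file).
# The number of classes and the batch size below are assumptions; `dataset`
# would be a hypothetical tf.data.Dataset yielding (segmentation_map, image,
# labels) batches as expected by train_step(). The generator above is
# hard-wired to a 1024x1024 output, so image_size must be 1024.
if __name__ == "__main__":
    gaugan = GauGAN(image_size=1024, num_classes=19, batch_size=1, latent_dim=256)
    gaugan.compile()  # also builds the VGG19 perceptual loss (downloads weights)

    # Training would look like:
    # gaugan.fit(dataset, epochs=10)

    # Inference: sample a latent vector and condition on a one-hot mask.
    latent = tf.random.normal((1, 256))
    mask = tf.zeros((1, 1024, 1024, 19))
    fake_image = gaugan([latent, mask])
    print(fake_image.shape)  # (1, 1024, 1024, 3)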