"""Losses for Generator and Discriminator.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import tensorflow as tf |
|
|
|
|
|
def discriminator_loss(predictions, labels, missing_tokens):
  """Discriminator loss based on predictions and labels.

  Args:
    predictions: Discriminator linear predictions Tensor of shape
      [batch_size, sequence_length].
    labels: Labels for predictions, Tensor of shape
      [batch_size, sequence_length].
    missing_tokens: Indicator for the missing tokens. Evaluate the loss only
      on the tokens that were missing.

  Returns:
    loss: Scalar tf.float32 loss.
  """
  loss = tf.losses.sigmoid_cross_entropy(
      labels, predictions, weights=missing_tokens)

  # Debug logging: tf.Print is an identity op that prints the first 25
  # entries of the listed tensors for the first 25 executions.
  loss = tf.Print(
      loss, [loss, labels, missing_tokens],
      message='loss, labels, missing_tokens',
      summarize=25,
      first_n=25)
  return loss
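
# Usage sketch (illustrative only; `dis_logits` and `missing` are assumed
# names, not part of this module):
#
#   dis_logits = ...                        # [batch_size, sequence_length]
#   labels = tf.ones_like(dis_logits)       # e.g. 1.0 where the token is real
#   d_loss = discriminator_loss(dis_logits, labels, missing)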
|
|
def cross_entropy_loss_matrix(gen_labels, gen_logits):
  """Computes the cross entropy loss for G.

  Args:
    gen_labels: Labels for the correct token.
    gen_logits: Generator logits.

  Returns:
    loss_matrix: Loss matrix of shape [batch_size, sequence_length].
  """
  cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=gen_labels, logits=gen_logits)
  return cross_entropy_loss
|
|
def GAN_loss_matrix(dis_predictions):
  """Computes the GAN loss for G.

  Args:
    dis_predictions: Discriminator predictions.

  Returns:
    loss_matrix: Loss matrix of shape [batch_size, sequence_length].
  """
  # Small epsilon for numerical stability before taking the log.
  eps = tf.constant(1e-7, tf.float32)
  gan_loss_matrix = -tf.log(dis_predictions + eps)
  return gan_loss_matrix
|
|
def generator_GAN_loss(predictions):
  """Generator GAN loss based on Discriminator predictions."""
  # Note: this takes the log of the mean prediction rather than the mean of
  # the per-token log, so it is not simply the mean of GAN_loss_matrix.
  return -tf.log(tf.reduce_mean(predictions))
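
# A small numeric sketch of how this differs from averaging GAN_loss_matrix
# (values illustrative; assumes predictions are probabilities in (0, 1]):
#
#   p = tf.constant([0.9, 0.1])
#   tf.reduce_mean(GAN_loss_matrix(p))   # mean(-log p)  ~= 1.20
#   generator_GAN_loss(p)                # -log(mean(p)) ~= 0.69
#
# By Jensen's inequality, -log(mean(p)) <= mean(-log(p)).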
|
|
def generator_blended_forward_loss(gen_logits, gen_labels, dis_predictions,
                                   is_real_input):
  """Computes the masked loss for G.

  This is a blend of cross-entropy loss where the true label is known and
  GAN loss where the true label has been masked.

  Args:
    gen_logits: Generator logits.
    gen_labels: Labels for the correct token.
    dis_predictions: Discriminator predictions.
    is_real_input: Tensor indicating whether the label is present.

  Returns:
    loss: Scalar tf.float32 total loss.
  """
  cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=gen_labels, logits=gen_logits)
  # Epsilon guards against log(0), matching GAN_loss_matrix above.
  eps = tf.constant(1e-7, tf.float32)
  gan_loss = -tf.log(dis_predictions + eps)
  loss_matrix = tf.where(is_real_input, cross_entropy_loss, gan_loss)
  return tf.reduce_mean(loss_matrix)
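
# Usage sketch (shapes illustrative): tf.where picks the cross-entropy term
# at positions where is_real_input is True and the GAN term elsewhere.
#
#   gen_logits:      [batch_size, seq_len, vocab_size] float32
#   gen_labels:      [batch_size, seq_len] int32
#   dis_predictions: [batch_size, seq_len] float32 probabilities
#   is_real_input:   [batch_size, seq_len] bool
#
#   g_loss = generator_blended_forward_loss(gen_logits, gen_labels,
#                                           dis_predictions, is_real_input)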
|
|
def wasserstein_generator_loss(gen_logits, gen_labels, dis_values,
                               is_real_input):
  """Computes the masked loss for G.

  This is a blend of cross-entropy loss where the true label is known and
  GAN loss where the true label is missing.

  Args:
    gen_logits: Generator logits.
    gen_labels: Labels for the correct token.
    dis_values: Discriminator values Tensor of shape [batch_size,
      sequence_length].
    is_real_input: Tensor indicating whether the label is present.

  Returns:
    loss: Scalar tf.float32 total loss.
  """
  cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=gen_labels, logits=gen_logits)

  # Wasserstein generator objective: maximize the critic's value on the
  # generated tokens, i.e. minimize its negation.
  gan_loss = -dis_values
  loss_matrix = tf.where(is_real_input, cross_entropy_loss, gan_loss)
  loss = tf.reduce_mean(loss_matrix)
  return loss
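
# Sketch of the blend on a single position axis (values illustrative):
#
#   is_real_input:  [True, False, True]
#   cross_entropy:  [2.1,  0.3,  1.4]    # used at positions 0 and 2
#   -dis_values:    [0.5, -1.2,  0.8]    # used at position 1
#   loss = mean([2.1, -1.2, 1.4]) ~= 0.77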
|
|
def wasserstein_discriminator_loss(real_values, fake_values):
  """Wasserstein discriminator loss.

  Args:
    real_values: Value given by the Wasserstein Discriminator to real data.
    fake_values: Value given by the Wasserstein Discriminator to fake data.

  Returns:
    loss: Scalar tf.float32 loss.
  """
  real_avg = tf.reduce_mean(real_values)
  fake_avg = tf.reduce_mean(fake_values)

  wasserstein_loss = real_avg - fake_avg
  return wasserstein_loss
|
|
def wasserstein_discriminator_loss_intrabatch(values, is_real_input):
  """Wasserstein discriminator loss.

  This is an odd variant where the value difference is between the real
  tokens and the fake tokens within a single batch.

  Args:
    values: Value given by the Wasserstein Discriminator of shape
      [batch_size, sequence_length] to an imputed batch (real and fake).
    is_real_input: tf.bool Tensor of shape [batch_size, sequence_length]. If
      True, it indicates that the label is known.

  Returns:
    wasserstein_loss: Scalar tf.float32 loss.
  """
  zero_tensor = tf.constant(0., dtype=tf.float32, shape=[])

  present = tf.cast(is_real_input, tf.float32)
  missing = 1. - present

  real_count = tf.reduce_sum(present)
  fake_count = tf.reduce_sum(missing)

  # Mask the values, then average only over the tokens in each group.
  # tf.mul was removed in TF 1.0; tf.multiply is the current name.
  real = tf.multiply(values, present)
  fake = tf.multiply(values, missing)
  # Guard the division so an all-real or all-fake batch yields 0 rather than
  # NaN (which would also poison the gradients through tf.where below).
  real_avg = tf.reduce_sum(real) / tf.maximum(real_count, 1.)
  fake_avg = tf.reduce_sum(fake) / tf.maximum(fake_count, 1.)

  real_avg = tf.where(tf.equal(real_count, 0), zero_tensor, real_avg)
  fake_avg = tf.where(tf.equal(fake_count, 0), zero_tensor, fake_avg)

  wasserstein_loss = real_avg - fake_avg
  return wasserstein_loss
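

# A minimal smoke test (a sketch, assuming TF 1.x graph mode; the shapes and
# random inputs below are illustrative and not part of the original module).
if __name__ == '__main__':
  tf.set_random_seed(0)
  batch_size, sequence_length, vocab_size = 2, 5, 10

  gen_logits = tf.random_normal([batch_size, sequence_length, vocab_size])
  gen_labels = tf.random_uniform(
      [batch_size, sequence_length], maxval=vocab_size, dtype=tf.int32)
  dis_predictions = tf.random_uniform([batch_size, sequence_length])
  is_real_input = tf.cast(
      tf.random_uniform(
          [batch_size, sequence_length], maxval=2, dtype=tf.int32), tf.bool)
  values = tf.random_normal([batch_size, sequence_length])

  blended = generator_blended_forward_loss(gen_logits, gen_labels,
                                           dis_predictions, is_real_input)
  intrabatch = wasserstein_discriminator_loss_intrabatch(values, is_real_input)

  with tf.Session() as sess:
    print('blended generator loss:', sess.run(blended))
    print('intrabatch wasserstein loss:', sess.run(intrabatch))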