File size: 6,138 Bytes
97b6013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Losses for Generator and Discriminator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def discriminator_loss(predictions, labels, missing_tokens):
  """Discriminator loss based on predictions and labels.

  Args:
    predictions:  Discriminator linear predictions Tensor of shape [batch_size,
      sequence_length]
    labels: Labels for predictions, Tensor of shape [batch_size,
      sequence_length]
    missing_tokens:  Indicator for the missing tokens.  Evaluate the loss only
      on the tokens that were missing.

  Returns:
    loss:  Scalar tf.float32 loss.

  """
  masked_loss = tf.losses.sigmoid_cross_entropy(
      labels, predictions, weights=missing_tokens)
  # Periodically log the loss, labels and mask for debugging (first 25 runs).
  masked_loss = tf.Print(
      masked_loss, [masked_loss, labels, missing_tokens],
      message='loss, labels, missing_tokens',
      summarize=25,
      first_n=25)
  return masked_loss


def cross_entropy_loss_matrix(gen_labels, gen_logits):
  """Computes the cross entropy loss for G.

  Args:
    gen_labels:  Labels for the correct token.
    gen_logits: Generator logits.

  Returns:
    loss_matrix:  Loss matrix of shape [batch_size, sequence_length].
  """
  # Per-token loss; no reduction is applied here.
  return tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=gen_labels, logits=gen_logits)


def GAN_loss_matrix(dis_predictions):
  """Computes the cross entropy loss for G.

  Args:
    dis_predictions:  Discriminator predictions.

  Returns:
    loss_matrix: Loss matrix of shape [batch_size, sequence_length].
  """
  # The small epsilon keeps the log away from -inf when a prediction is 0.
  epsilon = tf.constant(1e-7, tf.float32)
  return -tf.log(dis_predictions + epsilon)


def generator_GAN_loss(predictions):
  """Generator GAN loss based on Discriminator predictions."""
  mean_prediction = tf.reduce_mean(predictions)
  return -tf.log(mean_prediction)


def generator_blended_forward_loss(gen_logits, gen_labels, dis_predictions,
                                   is_real_input):
  """Computes the masked-loss for G.  This will be a blend of cross-entropy
  loss where the true label is known and GAN loss where the true label has been
  masked.

  Args:
    gen_logits: Generator logits.
    gen_labels:  Labels for the correct token.
    dis_predictions:  Discriminator predictions.
    is_real_input:  Tensor indicating whether the label is present.

  Returns:
    loss: Scalar tf.float32 total loss.
  """
  cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=gen_labels, logits=gen_logits)
  # Add the same 1e-7 epsilon used in GAN_loss_matrix: without it a
  # discriminator prediction of exactly 0 produces -inf here and NaNs the
  # blended loss even for masked positions (tf.where evaluates both branches).
  eps = tf.constant(1e-7, tf.float32)
  gan_loss = -tf.log(dis_predictions + eps)
  # Cross-entropy where the label is real, GAN loss where it was masked.
  loss_matrix = tf.where(is_real_input, cross_entropy_loss, gan_loss)
  return tf.reduce_mean(loss_matrix)


def wasserstein_generator_loss(gen_logits, gen_labels, dis_values,
                               is_real_input):
  """Computes the masked-loss for G.  This will be a blend of cross-entropy
  loss where the true label is known and GAN loss where the true label is
  missing.

  Args:
    gen_logits:  Generator logits.
    gen_labels:  Labels for the correct token.
    dis_values:  Discriminator values Tensor of shape [batch_size,
      sequence_length].
    is_real_input:  Tensor indicating whether the label is present.

  Returns:
    loss: Scalar tf.float32 total loss.
  """
  ce_per_token = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=gen_labels, logits=gen_logits)
  # For masked positions, maximize the critic values (minimize their negation).
  critic_term = -dis_values
  blended = tf.where(is_real_input, ce_per_token, critic_term)
  return tf.reduce_mean(blended)


def wasserstein_discriminator_loss(real_values, fake_values):
  """Wasserstein discriminator loss.

  Args:
    real_values: Value given by the Wasserstein Discriminator to real data.
    fake_values: Value given by the Wasserstein Discriminator to fake data.

  Returns:
    loss:  Scalar tf.float32 loss.

  """
  # The Wasserstein estimate is the gap between mean real and fake values.
  return tf.reduce_mean(real_values) - tf.reduce_mean(fake_values)


def wasserstein_discriminator_loss_intrabatch(values, is_real_input):
  """Wasserstein discriminator loss.  This is an odd variant where the value
  difference is between the real tokens and the fake tokens within a single
  batch.

  Args:
    values: Value given by the Wasserstein Discriminator of shape [batch_size,
      sequence_length] to an imputed batch (real and fake).
    is_real_input: tf.bool Tensor of shape [batch_size, sequence_length]. If
      true, it indicates that the label is known.

  Returns:
    wasserstein_loss:  Scalar tf.float32 loss.

  """
  zero_tensor = tf.constant(0., dtype=tf.float32, shape=[])

  present = tf.cast(is_real_input, tf.float32)
  missing = 1. - present

  # Counts for real and fake tokens.
  real_count = tf.reduce_sum(present)
  fake_count = tf.reduce_sum(missing)

  # Averages for real and fake token values.  tf.mul was removed in
  # TensorFlow 1.0; tf.multiply is the supported name.
  real = tf.multiply(values, present)
  fake = tf.multiply(values, missing)
  # Clamp the denominators: tf.where evaluates both branches, so an
  # unguarded sum/0 would inject NaN (and NaN gradients) even though the
  # zero branch is selected below.  Counts are integer-valued, so clamping
  # to 1 only affects the count == 0 case, which tf.where overrides anyway.
  real_avg = tf.reduce_sum(real) / tf.maximum(real_count, 1.)
  fake_avg = tf.reduce_sum(fake) / tf.maximum(fake_count, 1.)

  # If there are no real or fake entries in the batch, we assign an average
  # value of zero.
  real_avg = tf.where(tf.equal(real_count, 0), zero_tensor, real_avg)
  fake_avg = tf.where(tf.equal(fake_count, 0), zero_tensor, fake_avg)

  wasserstein_loss = real_avg - fake_avg
  return wasserstein_loss