# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Losses used for segmentation models."""

import tensorflow as tf
import tf_keras

from official.modeling import tf_utils
from official.vision.dataloaders import utils

EPSILON = 1e-5


class SegmentationLoss:
  """Semantic segmentation loss."""

  def __init__(self,
               label_smoothing,
               class_weights,
               ignore_label,
               use_groundtruth_dimension,
               use_binary_cross_entropy=False,
               top_k_percent_pixels=1.0,
               gt_is_matting_map=False):
    """Initializes `SegmentationLoss`.

    Args:
      label_smoothing: A float. If > 0, smooths the one-hot labels by
        spreading this amount of probability mass uniformly across all label
        classes.
      class_weights: A float list containing the weight of each class.
      ignore_label: An integer specifying the ignore label.
      use_groundtruth_dimension: A boolean, whether to resize the output to
        match the dimension of the ground truth.
      use_binary_cross_entropy: A boolean, if true, use binary cross entropy
        loss, otherwise, use categorical cross entropy.
      top_k_percent_pixels: A float in [0.0, 1.0]. If < 1.0, only the top k
        percent of pixels, ranked by loss value, contribute to the loss. This
        is useful for hard pixel mining.
      gt_is_matting_map: Whether the groundtruth mask is a matting map. Note
        that a matting map is only supported for 2-class segmentation.
    """
    self._label_smoothing = label_smoothing
    self._class_weights = class_weights
    self._ignore_label = ignore_label
    self._use_groundtruth_dimension = use_groundtruth_dimension
    self._use_binary_cross_entropy = use_binary_cross_entropy
    self._top_k_percent_pixels = top_k_percent_pixels
    self._gt_is_matting_map = gt_is_matting_map

  def __call__(self, logits, labels, **kwargs):
    """Computes `SegmentationLoss`.

    Args:
      logits: A float tensor in shape (batch_size, height, width, num_classes)
        which is the output of the network.
      labels: A tensor in shape (batch_size, height, width, num_layers), which
        contains the groundtruth label masks. num_layers can be > 1 if pixels
        are labeled as multiple classes.
      **kwargs: additional keyword arguments.

    Returns:
       A 0-D float which stores the overall loss of the batch.
    """
    _, height, width, num_classes = logits.get_shape().as_list()
    output_dtype = logits.dtype
    num_layers = labels.get_shape().as_list()[-1]
    if not self._use_binary_cross_entropy:
      if num_layers > 1:
        raise ValueError(
            'Groundtruth mask must have only 1 layer if using categorical '
            'cross entropy, but got {} layers.'.format(num_layers))
    if self._gt_is_matting_map:
      if num_classes != 2:
        raise ValueError(
            'Groundtruth matting map only supports 2 classes, but got {} '
            'classes.'.format(num_classes))
      if num_layers > 1:
        raise ValueError(
            'Groundtruth matting map must have only 1 layer, but got {} '
            'layers.'.format(num_layers))

    class_weights = (
        self._class_weights if self._class_weights else [1] * num_classes)
    if num_classes != len(class_weights):
      raise ValueError(
          'Length of class_weights should be {}, but got {}.'.format(
              num_classes, len(class_weights)))
    class_weights = tf.constant(class_weights, dtype=output_dtype)

    if not self._gt_is_matting_map:
      labels = tf.cast(labels, tf.int32)
    if self._use_groundtruth_dimension:
      # TODO(arashwan): Test using align corners to match deeplab alignment.
      logits = tf.image.resize(
          logits, tf.shape(labels)[1:3], method=tf.image.ResizeMethod.BILINEAR)
    else:
      labels = tf.image.resize(
          labels, (height, width),
          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    valid_mask = tf.not_equal(tf.cast(labels, tf.int32), self._ignore_label)

    # (batch_size, height, width, num_classes)
    labels_with_prob = self.get_labels_with_prob(logits, labels, valid_mask,
                                                 **kwargs)

    # (batch_size, height, width)
    valid_mask = tf.cast(tf.reduce_any(valid_mask, axis=-1), dtype=output_dtype)

    if self._use_binary_cross_entropy:
      # (batch_size, height, width, num_classes)
      cross_entropy_loss = tf.nn.sigmoid_cross_entropy_with_logits(
          labels=labels_with_prob, logits=logits)
      # (batch_size, height, width, num_classes)
      cross_entropy_loss *= class_weights
      num_valid_values = tf.reduce_sum(valid_mask) * tf.cast(
          num_classes, output_dtype)
      # (batch_size, height, width, num_classes)
      cross_entropy_loss *= valid_mask[..., tf.newaxis]
    else:
      # (batch_size, height, width)
      cross_entropy_loss = tf.nn.softmax_cross_entropy_with_logits(
          labels=labels_with_prob, logits=logits)

      # If the groundtruth is a matting map, binarize its values to create the
      # weight mask.
      if self._gt_is_matting_map:
        labels = utils.binarize_matting_map(labels)

      # (batch_size, height, width)
      weight_mask = tf.einsum(
          '...y,y->...',
          tf.one_hot(
              tf.cast(tf.squeeze(labels, axis=-1), tf.int32),
              depth=num_classes,
              dtype=output_dtype), class_weights)
      cross_entropy_loss *= weight_mask
      num_valid_values = tf.reduce_sum(valid_mask)
      cross_entropy_loss *= valid_mask

    if self._top_k_percent_pixels < 1.0:
      return self.aggregate_loss_top_k(cross_entropy_loss, num_valid_values)
    else:
      return tf.reduce_sum(cross_entropy_loss) / (num_valid_values + EPSILON)

  def get_labels_with_prob(self, logits, labels, valid_mask, **unused_kwargs):
    """Get a tensor representing the probability of each class for each pixel.

    This method can be overridden in subclasses to customize the loss function.

    Args:
      logits: A float tensor in shape (batch_size, height, width, num_classes)
        which is the output of the network.
      labels: A tensor in shape (batch_size, height, width, num_layers), which
        contains the groundtruth label masks. num_layers can be > 1 if pixels
        are labeled as multiple classes.
      valid_mask: A bool tensor in shape (batch_size, height, width, num_layers)
        which is True for pixels whose label is not the ignore label in each
        ground truth layer.
      **unused_kwargs: Unused keyword arguments.

    Returns:
       A float tensor in shape (batch_size, height, width, num_classes).
    """
    num_classes = logits.get_shape().as_list()[-1]

    if self._gt_is_matting_map:
      # (batch_size, height, width, num_classes=2)
      train_labels = tf.concat([1 - labels, labels], axis=-1)
    else:
      labels = tf.cast(labels, tf.int32)
      # Assign pixels with the ignore label to class -1, which the tf.one_hot
      # operation maps to an all-zero (all-off) vector.
      # (batch_size, height, width, num_masks)
      labels = tf.where(valid_mask, labels, -tf.ones_like(labels))

      if self._use_binary_cross_entropy:
        # (batch_size, height, width, num_masks, num_classes)
        one_hot_labels_per_mask = tf.one_hot(
            labels,
            depth=num_classes,
            on_value=True,
            off_value=False,
            dtype=tf.bool,
            axis=-1)
        # Aggregate all one-hot labels to get a binary mask in shape
        # (batch_size, height, width, num_classes), which represents all the
        # classes that a pixel is labeled as.
        # For example, if a pixel is labeled as "window" (id=1) and is also
        # part of a "building" (id=3), then its train_labels are [0,1,0,1].
        train_labels = tf.cast(
            tf.reduce_any(one_hot_labels_per_mask, axis=-2), dtype=logits.dtype)
      else:
        # (batch_size, height, width, num_classes)
        train_labels = tf.one_hot(
            tf.squeeze(labels, axis=-1), depth=num_classes, dtype=logits.dtype)

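    # Apply label smoothing: each target probability p becomes
    # p * (1 - label_smoothing) + label_smoothing / num_classes, spreading
    # `label_smoothing` of the probability mass uniformly over all classes.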
    return train_labels * (
        1 - self._label_smoothing) + self._label_smoothing / num_classes

  def aggregate_loss_top_k(self, pixelwise_loss, num_valid_pixels=None):
    """Aggregate the top-k greatest pixelwise loss.

    Args:
      pixelwise_loss: a float tensor in shape (batch_size, height, width) which
        stores the loss of each pixel.
      num_valid_pixels: the number of pixels which are not ignored. If None, all
        the pixels are valid.

    Returns:
       A 0-D float which stores the overall loss of the batch.
    """
    pixelwise_loss = tf.reshape(pixelwise_loss, shape=[-1])
    top_k_pixels = tf.cast(
        self._top_k_percent_pixels
        * tf.cast(tf.size(pixelwise_loss), tf.float32),
        tf.int32,
    )
    top_k_losses, _ = tf.math.top_k(pixelwise_loss, k=top_k_pixels)
    normalizer = tf.cast(top_k_pixels, top_k_losses.dtype)
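    # Ignored pixels carry zero loss but can still be selected into the top-k
    # set, so never normalize by more pixels than are actually valid.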
    if num_valid_pixels is not None:
      normalizer = tf.minimum(normalizer,
                              tf.cast(num_valid_pixels, top_k_losses.dtype))
    return tf.reduce_sum(top_k_losses) / (normalizer + EPSILON)
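
# Illustrative usage sketch (assumed values, not part of the library API):
# constructing `SegmentationLoss` and applying it to dummy logits and labels.
#
#   loss_fn = SegmentationLoss(
#       label_smoothing=0.0,
#       class_weights=None,
#       ignore_label=255,
#       use_groundtruth_dimension=True)
#   logits = tf.random.normal([2, 64, 64, 21])
#   labels = tf.random.uniform([2, 64, 64, 1], maxval=21, dtype=tf.int32)
#   loss = loss_fn(logits, labels)  # A 0-D scalar tensor.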


def get_actual_mask_scores(logits, labels, ignore_label):
  """Gets actual mask scores."""
  _, height, width, num_classes = logits.get_shape().as_list()
  batch_size = tf.shape(logits)[0]
  logits = tf.stop_gradient(logits)
  labels = tf.image.resize(
      labels, (height, width), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  predicted_labels = tf.argmax(logits, -1, output_type=tf.int32)
  flat_predictions = tf.reshape(predicted_labels, [batch_size, -1])
  flat_labels = tf.cast(tf.reshape(labels, [batch_size, -1]), tf.int32)

  one_hot_predictions = tf.one_hot(
      flat_predictions, num_classes, on_value=True, off_value=False)
  one_hot_labels = tf.one_hot(
      flat_labels, num_classes, on_value=True, off_value=False)
  keep_mask = tf.not_equal(flat_labels, ignore_label)
  keep_mask = tf.expand_dims(keep_mask, 2)

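  # Per-class intersection and union over pixels whose groundtruth label is
  # not the ignore label; their ratio is the per-class IoU.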
  overlap = tf.logical_and(one_hot_predictions, one_hot_labels)
  overlap = tf.logical_and(overlap, keep_mask)
  overlap = tf.reduce_sum(tf.cast(overlap, tf.float32), axis=1)
  union = tf.logical_or(one_hot_predictions, one_hot_labels)
  union = tf.logical_and(union, keep_mask)
  union = tf.reduce_sum(tf.cast(union, tf.float32), axis=1)
  actual_scores = tf.divide(overlap, tf.maximum(union, EPSILON))
  return actual_scores


class MaskScoringLoss:
  """Mask Scoring loss."""

  def __init__(self, ignore_label):
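    """Initializes `MaskScoringLoss`.

    Args:
      ignore_label: An integer specifying the ignore label.
    """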
    self._ignore_label = ignore_label
    self._mse_loss = tf_keras.losses.MeanSquaredError(
        reduction=tf_keras.losses.Reduction.NONE)

  def __call__(self, predicted_scores, logits, labels):
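    """Computes the MSE between predicted and actual per-class IoU scores.

    Args:
      predicted_scores: A float tensor in shape (batch_size, num_classes) with
        the mask scores predicted by the network.
      logits: A float tensor in shape (batch_size, height, width, num_classes)
        which is the output of the network.
      labels: A tensor in shape (batch_size, height, width, 1) which contains
        the groundtruth label masks.

    Returns:
      A 0-D float which stores the overall loss of the batch.
    """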
    actual_scores = get_actual_mask_scores(logits, labels, self._ignore_label)
    loss = tf_utils.safe_mean(self._mse_loss(actual_scores, predicted_scores))
    return loss
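
# Illustrative usage sketch (assumed shapes, not part of the library API): a
# scoring head predicts per-class mask IoUs, which are regressed against the
# actual IoUs computed from the segmentation logits.
#
#   scoring_loss = MaskScoringLoss(ignore_label=255)
#   predicted_scores = tf.random.uniform([2, 21])  # (batch_size, num_classes)
#   logits = tf.random.normal([2, 64, 64, 21])
#   labels = tf.random.uniform([2, 64, 64, 1], maxval=21, dtype=tf.int32)
#   loss = scoring_loss(predicted_scores, logits, labels)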