Spaces:
Sleeping
Sleeping
File size: 30,213 Bytes
5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 |
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Losses used for detection models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
import tensorflow as tf, tf_keras
def focal_loss(logits, targets, alpha, gamma, normalizer):
  """Computes the focal loss between `logits` and the golden `targets` values.

  Focal loss = -(1 - pt)^gamma * log(pt)
  where pt is the probability of being classified to the true class.

  Args:
    logits: A float32 tensor of size
      [batch, height_in, width_in, num_predictions].
    targets: A float32 tensor of size
      [batch, height_in, width_in, num_predictions].
    alpha: A float32 scalar multiplying alpha to the loss from positive examples
      and (1-alpha) to the loss from negative examples.
    gamma: A float32 scalar modulating loss from hard and easy examples.
    normalizer: A float32 scalar normalizes the total loss from all examples.

  Returns:
    loss: A float32 Tensor of size [batch, height_in, width_in, num_predictions]
      representing normalized loss on the prediction map.
  """
  with tf.name_scope('focal_loss'):
    positive_label_mask = tf.math.equal(targets, 1.0)
    cross_entropy = (
        tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits))
    # Below are comments/derivations for computing modulator.
    # For brevity, let x = logits, z = targets, r = gamma, and p_t = sigmoid(x)
    # for positive samples and 1 - sigmoid(x) for negative examples.
    #
    # The modulator, defined as (1 - P_t)^r, is a critical part in focal loss
    # computation. For r > 0, it puts more weights on hard examples, and less
    # weights on easier ones. However if it is directly computed as (1 - P_t)^r,
    # its back-propagation is not stable when r < 1. The implementation here
    # resolves the issue.
    #
    # For positive samples (labels being 1),
    #    (1 - p_t)^r
    #  = (1 - sigmoid(x))^r
    #  = (1 - (1 / (1 + exp(-x))))^r
    #  = (exp(-x) / (1 + exp(-x)))^r
    #  = exp(log((exp(-x) / (1 + exp(-x)))^r))
    #  = exp(r * log(exp(-x)) - r * log(1 + exp(-x)))
    #  = exp(- r * x - r * log(1 + exp(-x)))
    #
    # For negative samples (labels being 0),
    #    (1 - p_t)^r
    #  = (sigmoid(x))^r
    #  = (1 / (1 + exp(-x)))^r
    #  = exp(log((1 / (1 + exp(-x)))^r))
    #  = exp(-r * log(1 + exp(-x)))
    #
    # Therefore one unified form for positive (z = 1) and negative (z = 0)
    # samples is:
    #    (1 - p_t)^r = exp(-r * z * x - r * log(1 + exp(-x))).
    #
    # log(1 + exp(-x)) is evaluated as softplus(-x). Spelling it out as
    # log1p(exp(-x)) overflows in float32 once -x exceeds roughly 88, which
    # drives the modulator to 0 and yields non-finite gradients for very
    # confident wrong predictions; softplus computes the same quantity without
    # materializing exp(-x).
    neg_logits = -1.0 * logits
    modulator = tf.math.exp(gamma * targets * neg_logits -
                            gamma * tf.nn.softplus(neg_logits))
    loss = modulator * cross_entropy
    # Positives are weighted by alpha, everything else by (1 - alpha).
    weighted_loss = tf.where(positive_label_mask, alpha * loss,
                             (1.0 - alpha) * loss)
    weighted_loss /= normalizer
  return weighted_loss
class RpnScoreLoss(object):
  """Objectness (score) loss of the Region Proposal Network."""

  def __init__(self, params):
    """Stores the per-image RPN batch size and builds the summed BCE loss.

    Args:
      params: config object exposing `rpn_batch_size_per_im`.
    """
    self._rpn_batch_size_per_im = params.rpn_batch_size_per_im
    self._binary_crossentropy = tf_keras.losses.BinaryCrossentropy(
        reduction=tf_keras.losses.Reduction.SUM, from_logits=True)

  def __call__(self, score_outputs, labels):
    """Sums the RPN score loss over all feature levels.

    Args:
      score_outputs: an OrderDict with keys representing levels and values
        representing scores in [batch_size, height, width, num_anchors].
      labels: the dictionary returned from the dataloader that holds the
        groundtruth score targets, indexed by the same levels.

    Returns:
      rpn_score_loss: a scalar tensor representing total score loss.
    """
    with tf.name_scope('rpn_loss'):
      per_level_losses = []
      for level in sorted(score_outputs.keys()):
        level_scores = score_outputs[level]
        level_targets = labels[level]
        # Each level is normalized by batch_size * rpn_batch_size_per_im,
        # where batch_size is read dynamically from the score tensor.
        level_normalizer = tf.cast(
            tf.shape(level_scores)[0] * self._rpn_batch_size_per_im,
            dtype=tf.float32)
        per_level_losses.append(
            self._rpn_score_loss(
                level_scores, level_targets, normalizer=level_normalizer))
      # Sums per level losses to total loss.
      return tf.math.add_n(per_level_losses)

  def _rpn_score_loss(self, score_outputs, score_targets, normalizer=1.0):
    """Computes the masked binary cross-entropy score loss of one level.

    `score_targets` takes three values: 1 for a positive anchor, 0 for a
    negative anchor and -1 for an anchor that must be ignored.

    Args:
      score_outputs: score logits of one level.
      score_targets: score targets of one level, same shape as the logits.
      normalizer: value the summed loss is divided by.

    Returns:
      A scalar tensor with the normalized score loss.
    """
    with tf.name_scope('rpn_score_loss'):
      is_positive = tf.math.equal(score_targets, 1)
      is_negative = tf.math.equal(score_targets, 0)
      # Only positive and negative anchors contribute to the loss.
      mask = tf.math.logical_or(is_positive, is_negative)
      # Ignored anchors (-1) are clamped to 0; the mask zeroes them out anyway.
      clamped_targets = tf.math.maximum(score_targets,
                                        tf.zeros_like(score_targets))
      # A trailing unit axis makes the Keras loss treat every anchor as its
      # own sample so that `sample_weight` applies per anchor.
      score_loss = self._binary_crossentropy(
          tf.expand_dims(clamped_targets, axis=-1),
          tf.expand_dims(score_outputs, axis=-1),
          sample_weight=mask)
      return score_loss / normalizer
class RpnBoxLoss(object):
  """Box regression loss of the Region Proposal Network."""

  def __init__(self, params):
    """Builds the summed Huber loss from `params.huber_loss_delta`.

    Args:
      params: config object exposing `huber_loss_delta`.
    """
    logging.info('RpnBoxLoss huber_loss_delta %s', params.huber_loss_delta)
    # The delta is typically around the mean value of regression target.
    # for instances, the regression targets of 512x512 input with 6 anchors on
    # P2-P6 pyramid is about [0.1, 0.1, 0.2, 0.2].
    self._huber_loss = tf_keras.losses.Huber(
        delta=params.huber_loss_delta, reduction=tf_keras.losses.Reduction.SUM)

  def __call__(self, box_outputs, labels):
    """Sums the RPN box regression loss over all feature levels.

    Args:
      box_outputs: an OrderDict with keys representing levels and values
        representing box regression targets in
        [batch_size, height, width, num_anchors * 4].
      labels: the dictionary returned from the dataloader that holds the
        groundtruth box targets, indexed by the same levels.

    Returns:
      rpn_box_loss: a scalar tensor representing total box regression loss.
    """
    with tf.name_scope('rpn_loss'):
      per_level_losses = [
          self._rpn_box_loss(box_outputs[level], labels[level])
          for level in sorted(box_outputs.keys())
      ]
      # Sum per level losses to total loss.
      return tf.add_n(per_level_losses)

  def _rpn_box_loss(self, box_outputs, box_targets, normalizer=1.0):
    """Computes the masked Huber box regression loss of one level.

    Args:
      box_outputs: predicted box regression values of one level.
      box_targets: box regression targets of one level; entries equal to 0
        carry zero weight.
      normalizer: extra factor multiplied into the denominator.

    Returns:
      A scalar tensor with the normalized box loss.
    """
    with tf.name_scope('rpn_box_loss'):
      # Targets that are exactly zero are excluded from the loss.
      nonzero_weight = tf.cast(
          tf.not_equal(box_targets, 0.0), dtype=tf.float32)
      summed_loss = self._huber_loss(
          tf.expand_dims(box_targets, axis=-1),
          tf.expand_dims(box_outputs, axis=-1),
          sample_weight=nonzero_weight)
      # Normalize by the number of non-zero weights times the caller-provided
      # normalizer; the + 0.01 avoids a division by zero.
      denominator = normalizer * (tf.reduce_sum(nonzero_weight) + 0.01)
      return summed_loss / denominator
class OlnRpnCenterLoss(object):
  """Centerness regression loss of the Object Localization Network RPN."""

  def __init__(self):
    """Builds the summed L1 (mean absolute error) loss."""
    self._l1_loss = tf_keras.losses.MeanAbsoluteError(
        reduction=tf_keras.losses.Reduction.SUM)

  def __call__(self, center_outputs, labels):
    """Sums the RPN centerness regression loss over all feature levels.

    Args:
      center_outputs: an OrderDict with keys representing levels and values
        representing anchor centerness logits.
        NOTE(review): the previous docstring gave the shape as
        [batch_size, height, width, num_anchors * 4], which looks copied from
        the box loss -- confirm against the caller.
      labels: the dictionary returned from the dataloader that holds the
        groundtruth centerness targets, indexed by the same levels. Targets
        follow the convention 0 < positive < 1, negative = 0, ignore = -1.

    Returns:
      rpn_center_loss: a scalar tensor representing total centerness regression
        loss.
    """
    with tf.name_scope('rpn_loss'):
      levels = sorted(center_outputs.keys())
      # Normalizer: number of anchors over all levels whose target is above -1,
      # i.e. anchors inside and outside of boxes but not the ignored ones.
      num_valid = sum(
          tf.reduce_sum(
              tf.cast(tf.greater(labels[level], -1.0), tf.float32))
          for level in levels)
      # Keeps the normalizer strictly positive.
      num_valid += 1e-12
      # Centerness loss over multi levels.
      per_level_losses = [
          self._rpn_center_l1_loss(
              center_outputs[level], labels[level], normalizer=num_valid)
          for level in levels
      ]
      # Sum per level losses to total loss.
      return tf.add_n(per_level_losses)

  def _rpn_center_l1_loss(self, center_outputs, center_targets,
                          normalizer=1.0):
    """Computes the masked L1 centerness loss of one level.

    Args:
      center_outputs: centerness logits of one level; a sigmoid is applied
        before they are compared with the targets.
      center_targets: centerness targets of one level; entries that are not
        greater than -1 carry zero weight.
      normalizer: value the summed loss is divided by.

    Returns:
      A scalar tensor with the normalized centerness loss.
    """
    with tf.name_scope('rpn_center_loss'):
      # Anchors inside and outside of boxes contribute (target > -1); only
      # the ignored anchors are masked out.
      valid_weight = tf.cast(
          tf.greater(center_targets, -1.0), dtype=tf.float32)
      # Ignored targets (-1) are clamped to 0; their weight is 0 anyway.
      clamped_targets = tf.maximum(center_targets,
                                   tf.zeros_like(center_targets))
      center_probs = tf.sigmoid(center_outputs)
      summed_loss = self._l1_loss(
          tf.expand_dims(clamped_targets, -1),
          tf.expand_dims(center_probs, -1),
          sample_weight=valid_weight)
      return summed_loss / normalizer
class OlnRpnIoULoss(object):
  """IoU loss on the box-lrtb regression of the Object Localization Network."""

  def __call__(self, box_outputs, labels, center_targets):
    """Sums the RPN IoU box regression loss over all feature levels.

    Args:
      box_outputs: an OrderDict with keys representing levels and values
        representing box regression targets in
        [batch_size, height, width, num_anchors * 4].
        last channel: (left, right, top, bottom).
      labels: the dictionary returned from the dataloader that holds the
        groundtruth targets (left, right, top, bottom), indexed by level.
      center_targets: per-level centerness targets used as a valid-target
        mask (positive > 0, negative = 0, ignore = -1); only index 0 of the
        last axis is read.

    Returns:
      rpn_iou_loss: a scalar tensor representing total box regression loss.
    """
    with tf.name_scope('rpn_loss'):
      levels = sorted(box_outputs.keys())
      # Normalizer: number of anchors over all levels that have a positive
      # centerness target and strictly positive box targets.
      normalizer = sum(
          (tf.reduce_sum(
              tf.cast(
                  tf.logical_and(
                      tf.greater(center_targets[level][..., 0], 0.0),
                      tf.greater(tf.reduce_min(labels[level], -1), 0.0)),
                  tf.float32)) for level in levels),
          0.)
      # Keeps the normalizer strictly positive.
      normalizer += 1e-8
      # iou_loss over multi levels.
      per_level_losses = [
          self._rpn_iou_loss(
              box_outputs[level], labels[level],
              center_weight=center_targets[level][..., 0],
              normalizer=normalizer)
          for level in levels
      ]
      # Sum per level losses to total loss.
      return tf.add_n(per_level_losses)

  def _rpn_iou_loss(self, box_outputs, box_targets,
                    center_weight=None, normalizer=1.0):
    """Computes the -log(IoU) box regression loss of one level.

    Args:
      box_outputs: predicted (left, right, top, bottom) distances in the last
        axis.
      box_targets: groundtruth (left, right, top, bottom) distances in the
        last axis.
      center_weight: centerness targets; only entries greater than 0 are
        treated as valid.
      normalizer: value the summed loss is divided by.

    Returns:
      A scalar tensor with the normalized IoU loss.
    """
    with tf.name_scope('rpn_iou_loss'):
      # Valid anchors need a positive centerness target and four strictly
      # positive box targets.
      valid = tf.logical_and(
          tf.greater(center_weight, 0.0),
          tf.greater(tf.reduce_min(box_targets, -1), 0.0))
      valid_weight = tf.cast(valid, tf.float32)

      pred_left, pred_right, pred_top, pred_bottom = (
          box_outputs[..., side] for side in range(4))
      gt_left, gt_right, gt_top, gt_bottom = (
          box_targets[..., side] for side in range(4))

      # Both boxes are expressed as distances from the same point, so the
      # overlap along an axis is the sum of the per-side minima.
      overlap_width = (tf.minimum(pred_left, gt_left) +
                       tf.minimum(pred_right, gt_right))
      overlap_height = (tf.minimum(pred_top, gt_top) +
                        tf.minimum(pred_bottom, gt_bottom))
      overlap_area = overlap_width * overlap_height

      pred_area = (pred_left + pred_right) * (pred_top + pred_bottom)
      gt_area = (gt_left + gt_right) * (gt_top + gt_bottom)
      union_area = pred_area + gt_area - overlap_area

      iou = overlap_area / (union_area + 1e-8)
      # Clipping keeps the logarithm finite.
      iou = tf.clip_by_value(iou, clip_value_min=1e-8, clip_value_max=1.0)
      summed_loss = tf.reduce_sum(-tf.math.log(iou) * valid_weight)
      return summed_loss / normalizer
class FastrcnnClassLoss(object):
  """Classification loss of the Fast R-CNN head."""

  def __init__(self):
    """Builds the summed softmax cross-entropy loss on logits."""
    self._categorical_crossentropy = tf_keras.losses.CategoricalCrossentropy(
        reduction=tf_keras.losses.Reduction.SUM, from_logits=True)

  def __call__(self, class_outputs, class_targets):
    """Computes the class loss (Fast-RCNN branch) of Mask-RCNN.

    The classification loss is a softmax cross-entropy over all RoIs.
    Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/fast_rcnn_heads.py  # pylint: disable=line-too-long

    The static shape of `class_outputs` is read to obtain the number of
    classes and the normalizer `batch_size * num_boxes / 2.0`.

    Args:
      class_outputs: a float tensor representing the class prediction for each
        box with a shape of [batch_size, num_boxes, num_classes].
      class_targets: a float tensor representing the class label for each box
        with a shape of [batch_size, num_boxes].

    Returns:
      a scalar tensor representing total class loss.
    """
    with tf.name_scope('fast_rcnn_loss'):
      batch_size, num_boxes, num_classes = class_outputs.get_shape().as_list()
      int_targets = tf.cast(class_targets, dtype=tf.int32)
      one_hot_targets = tf.one_hot(int_targets, num_classes)
      loss_normalizer = batch_size * num_boxes / 2.0
      return self._fast_rcnn_class_loss(
          class_outputs, one_hot_targets, normalizer=loss_normalizer)

  def _fast_rcnn_class_loss(self, class_outputs, class_targets_one_hot,
                            normalizer):
    """Computes the summed softmax cross-entropy divided by `normalizer`.

    Args:
      class_outputs: class logits.
      class_targets_one_hot: one-hot encoded class targets.
      normalizer: value the summed loss is divided by.

    Returns:
      A scalar tensor with the normalized classification loss.
    """
    with tf.name_scope('fast_rcnn_class_loss'):
      summed_loss = self._categorical_crossentropy(class_targets_one_hot,
                                                   class_outputs)
      return summed_loss / normalizer
class FastrcnnBoxLoss(object):
  """Box regression loss of the Fast R-CNN head."""

  def __init__(self, params):
    """Builds the summed Huber loss from `params.huber_loss_delta`.

    Args:
      params: config object exposing `huber_loss_delta`.
    """
    logging.info('FastrcnnBoxLoss huber_loss_delta %s', params.huber_loss_delta)
    # The delta is typically around the mean value of regression target.
    # for instances, the regression targets of 512x512 input with 6 anchors on
    # P2-P6 pyramid is about [0.1, 0.1, 0.2, 0.2].
    self._huber_loss = tf_keras.losses.Huber(
        delta=params.huber_loss_delta, reduction=tf_keras.losses.Reduction.SUM)

  def __call__(self, box_outputs, class_targets, box_targets):
    """Computes the box loss (Fast-RCNN branch) of Mask-RCNN.

    `box_outputs` holds `num_classes` boxes for each RoI. The reference model
    expands `box_targets` to match that shape and keeps only the target the
    RoI overlaps most. (Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/roi_data/fast_rcnn.py)  # pylint: disable=line-too-long
    This implementation instead picks, for every RoI, the predicted box of its
    target class, so `box_targets` never has to be expanded.

    The box loss is smooth L1-loss on only positive samples of RoIs.
    Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/fast_rcnn_heads.py  # pylint: disable=line-too-long

    Args:
      box_outputs: a float tensor representing the box prediction for each box
        with a shape of [batch_size, num_boxes, num_classes * 4].
      class_targets: a float tensor representing the class label for each box
        with a shape of [batch_size, num_boxes].
      box_targets: a float tensor representing the box label for each box
        with a shape of [batch_size, num_boxes, 4].

    Returns:
      box_loss: a scalar tensor representing total box regression loss.
    """
    with tf.name_scope('fast_rcnn_loss'):
      class_targets = tf.cast(class_targets, dtype=tf.int32)

      # The static shape is needed to build the flat selection indices.
      (batch_size, num_rois,
       num_class_specific_boxes) = box_outputs.get_shape().as_list()
      num_classes = num_class_specific_boxes // 4
      box_outputs = tf.reshape(box_outputs,
                               [batch_size, num_rois, num_classes, 4])

      # Row index of box (b, r, c) in the flattened
      # [batch_size * num_rois * num_classes, 4] matrix is
      # b * num_rois * num_classes + r * num_classes + c, with c taken from
      # `class_targets`.
      batch_offsets = tf.tile(
          tf.expand_dims(tf.range(batch_size) * num_rois * num_classes, 1),
          [1, num_rois])
      roi_offsets = tf.tile(
          tf.expand_dims(tf.range(num_rois) * num_classes, 0),
          [batch_size, 1])
      box_indices = tf.reshape(class_targets + batch_offsets + roi_offsets,
                               [-1])

      # Selects the rows with a one-hot matrix multiplication.
      selector = tf.one_hot(
          box_indices,
          batch_size * num_rois * num_classes,
          dtype=box_outputs.dtype)
      flat_boxes = tf.reshape(box_outputs, [-1, 4])
      box_outputs = tf.reshape(
          tf.matmul(selector, flat_boxes), [batch_size, -1, 4])

      return self._fast_rcnn_box_loss(box_outputs, box_targets, class_targets)

  def _fast_rcnn_box_loss(self, box_outputs, box_targets, class_targets,
                          normalizer=1.0):
    """Computes the Huber box loss restricted to foreground RoIs.

    Args:
      box_outputs: selected box predictions, 4 values per RoI.
      box_targets: box targets, 4 values per RoI.
      class_targets: integer class label per RoI; RoIs whose label is not
        greater than 0 carry zero weight.
      normalizer: extra factor multiplied into the denominator.

    Returns:
      A scalar tensor with the normalized box loss.
    """
    with tf.name_scope('fast_rcnn_box_loss'):
      # Repeats the per-RoI foreground flag for each of the 4 coordinates.
      is_foreground = tf.expand_dims(tf.greater(class_targets, 0), axis=2)
      weights = tf.cast(tf.tile(is_foreground, [1, 1, 4]), dtype=tf.float32)
      summed_loss = self._huber_loss(
          tf.expand_dims(box_targets, axis=-1),
          tf.expand_dims(box_outputs, axis=-1),
          sample_weight=weights)
      # Normalize by the number of ones in the weights times the
      # caller-provided normalizer; the + 0.01 avoids a division by 0.
      denominator = normalizer * (tf.reduce_sum(weights) + 0.01)
      return summed_loss / denominator
class OlnBoxScoreLoss(object):
  """Box-IoU scoring loss of the Object Localization Network."""

  def __init__(self, params):
    """Stores the ignore threshold and builds the summed L1 loss.

    Args:
      params: config object exposing `ignore_threshold`.
    """
    self._ignore_threshold = params.ignore_threshold
    self._l1_loss = tf_keras.losses.MeanAbsoluteError(
        reduction=tf_keras.losses.Reduction.SUM)

  def __call__(self, score_outputs, score_targets):
    """Computes the L1 loss between sigmoid scores and score targets.

    Boxes whose target is not greater than `ignore_threshold` are excluded;
    the summed absolute error is divided by the number of remaining boxes.

    Args:
      score_outputs: a float tensor of score logits whose last axis has size
        1 (it is squeezed away), presumably [batch_size, num_boxes, 1].
      score_targets: a float tensor of score targets with the shape of the
        squeezed outputs, presumably [batch_size, num_boxes].

    Returns:
      a scalar tensor representing total score loss.
    """
    with tf.name_scope('fast_rcnn_loss'):
      squeezed_logits = tf.squeeze(score_outputs, -1)
      # Only boxes with a target above the ignore threshold count.
      valid_weight = tf.cast(
          tf.greater(score_targets, self._ignore_threshold),
          dtype=tf.float32)
      num_valid = tf.reduce_sum(valid_weight)
      # Negative targets are clamped to 0 before the comparison.
      clamped_targets = tf.maximum(score_targets,
                                   tf.zeros_like(score_targets))
      score_probs = tf.sigmoid(squeezed_logits)
      summed_loss = self._l1_loss(
          tf.expand_dims(clamped_targets, -1),
          tf.expand_dims(score_probs, -1),
          sample_weight=valid_weight)
      # The 1e-10 avoids a division by zero when no box is valid.
      return summed_loss / (num_valid + 1e-10)
class MaskrcnnLoss(object):
  """Instance segmentation mask loss of Mask R-CNN."""

  def __init__(self):
    """Builds the summed binary cross-entropy loss on logits."""
    self._binary_crossentropy = tf_keras.losses.BinaryCrossentropy(
        reduction=tf_keras.losses.Reduction.SUM, from_logits=True)

  def __call__(self, mask_outputs, mask_targets, select_class_targets):
    """Computes the mask loss of Mask-RCNN.

    `mask_outputs` holds `num_classes` masks for each RoI in the reference
    model, which expands `mask_targets` to that shape and keeps only the target
    the RoI overlaps most. (Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/roi_data/mask_rcnn.py)  # pylint: disable=line-too-long
    This implementation expects masks that were already selected by
    `class_targets`, so `mask_targets` is not expanded. Note that the selection
    logic is done in the post-processing of mask_rcnn_fn in
    mask_rcnn_architecture.py.

    Args:
      mask_outputs: a float tensor representing the prediction for each mask,
        with a shape of
        [batch_size, num_masks, mask_height, mask_width].
      mask_targets: a float tensor representing the binary mask of ground truth
        labels for each mask with a shape of
        [batch_size, num_masks, mask_height, mask_width].
      select_class_targets: a tensor with a shape of [batch_size, num_masks],
        representing the foreground mask targets.

    Returns:
      mask_loss: a float tensor representing total mask loss.
    """
    with tf.name_scope('mask_rcnn_loss'):
      (batch_size, num_masks, mask_height,
       mask_width) = mask_outputs.get_shape().as_list()
      # Per-mask foreground flag, repeated for every pixel of the mask.
      is_foreground = tf.reshape(
          tf.greater(select_class_targets, 0), [batch_size, num_masks, 1, 1])
      weights = tf.cast(
          tf.tile(is_foreground, [1, 1, mask_height, mask_width]),
          dtype=tf.float32)
      summed_loss = self._binary_crossentropy(
          tf.expand_dims(mask_targets, axis=-1),
          tf.expand_dims(mask_outputs, axis=-1),
          sample_weight=weights)
      # The loss is normalized by the number of 1's in weights and
      # + 0.01 is used to avoid division by zero.
      return summed_loss / (tf.reduce_sum(weights) + 0.01)
class RetinanetClassLoss(object):
  """Focal classification loss of RetinaNet."""

  def __init__(self, params, num_classes):
    """Stores the number of classes and the focal loss hyper-parameters.

    Args:
      params: config object exposing `focal_loss_alpha` and
        `focal_loss_gamma`.
      num_classes: depth of the one-hot encoding of the class targets.
    """
    self._num_classes = num_classes
    self._focal_loss_alpha = params.focal_loss_alpha
    self._focal_loss_gamma = params.focal_loss_gamma

  def __call__(self, cls_outputs, labels, num_positives):
    """Sums the classification loss over all feature levels.

    Args:
      cls_outputs: an OrderDict with keys representing levels and values
        representing logits in [batch_size, height, width,
        num_anchors * num_classes].
      labels: the dictionary returned from the dataloader that holds the
        class groundtruth targets, indexed by the same levels.
      num_positives: number of positive examples in the minibatch.

    Returns:
      a scalar float tensor representing total class loss.
    """
    # Sums all positives in a batch for normalization; the + 1.0 avoids a zero
    # normalizer, which would lead to inf loss during training.
    num_positives_sum = tf.reduce_sum(input_tensor=num_positives) + 1.0

    per_level_losses = [
        self.class_loss(cls_outputs[level], labels[level], num_positives_sum)
        for level in cls_outputs.keys()
    ]
    # Sums per level losses to total loss.
    return tf.add_n(per_level_losses)

  def class_loss(self, cls_outputs, cls_targets, num_positives,
                 ignore_label=-2):
    """Computes the RetinaNet classification loss of one level.

    Args:
      cls_outputs: class logits of one level.
      cls_targets: integer class targets of one level; after one-hot encoding
        they must be rank 5 (the code unpacks five static dimensions).
      num_positives: normalizer passed to the focal loss.
      ignore_label: target value whose anchors are excluded from the loss.

    Returns:
      A scalar tensor with the summed focal loss of the level.
    """
    # One-hot encodes the targets and folds the last two axes into one.
    one_hot_targets = tf.one_hot(cls_targets, self._num_classes)
    bs, height, width, _, _ = one_hot_targets.get_shape().as_list()
    one_hot_targets = tf.reshape(one_hot_targets, [bs, height, width, -1])

    loss = focal_loss(tf.cast(cls_outputs, dtype=tf.float32),
                      tf.cast(one_hot_targets, dtype=tf.float32),
                      self._focal_loss_alpha,
                      self._focal_loss_gamma,
                      num_positives)

    # Weight of 0.0 for anchors carrying `ignore_label`, 1.0 for the rest,
    # repeated for every class and laid out like `loss`.
    keep_weight = tf.cast(
        tf.not_equal(cls_targets, ignore_label), dtype=tf.float32)
    keep_weight = tf.tile(
        tf.expand_dims(keep_weight, -1), [1, 1, 1, 1, self._num_classes])
    keep_weight = tf.reshape(keep_weight, tf.shape(input=loss))

    return tf.reduce_sum(input_tensor=keep_weight * loss)
class RetinanetBoxLoss(object):
  """Huber box regression loss of RetinaNet."""

  def __init__(self, params):
    """Builds the summed Huber loss from `params.huber_loss_delta`.

    Args:
      params: config object exposing `huber_loss_delta`.
    """
    self._huber_loss = tf_keras.losses.Huber(
        delta=params.huber_loss_delta, reduction=tf_keras.losses.Reduction.SUM)

  def __call__(self, box_outputs, labels, num_positives):
    """Sums the box regression loss over all feature levels.

    Args:
      box_outputs: an OrderDict with keys representing levels and values
        representing box regression targets in [batch_size, height, width,
        num_anchors * 4].
      labels: the dictionary returned from the dataloader that holds the
        box groundtruth targets, indexed by the same levels.
      num_positives: number of positive examples in the minibatch.

    Returns:
      a scalar tensor representing total box regression loss.
    """
    # Sums all positives in a batch for normalization; the + 1.0 avoids a zero
    # normalizer, which would lead to inf loss during training.
    num_positives_sum = tf.reduce_sum(input_tensor=num_positives) + 1.0

    per_level_losses = [
        self.box_loss(box_outputs[level], labels[level], num_positives_sum)
        for level in box_outputs.keys()
    ]
    # Sums per level losses to total loss.
    return tf.add_n(per_level_losses)

  def box_loss(self, box_outputs, box_targets, num_positives):
    """Computes the RetinaNet box regression loss of one level.

    Args:
      box_outputs: predicted box regression values of one level.
      box_targets: box regression targets of one level; entries equal to 0
        carry zero weight.
      num_positives: number of positive examples; the summed loss is divided
        by four times this value.

    Returns:
      A scalar tensor with the normalized box loss.
    """
    # The delta is typically around the mean value of regression target.
    # for instances, the regression targets of 512x512 input with 6 anchors on
    # P3-P7 pyramid is about [0.1, 0.1, 0.2, 0.2].
    nonzero_weight = tf.cast(tf.not_equal(box_targets, 0.0), dtype=tf.float32)
    summed_loss = self._huber_loss(
        tf.expand_dims(box_targets, axis=-1),
        tf.expand_dims(box_outputs, axis=-1),
        sample_weight=nonzero_weight)
    return summed_loss / (num_positives * 4.0)
class ShapemaskMseLoss(object):
  """ShapeMask mask Mean Squared Error loss function wrapper."""

  def __call__(self, probs, labels, valid_mask):
    """Computes the masked L2 loss between mask predictions and labels.

    Args:
      probs: A Tensor of mask predictions. The values are not necessarily
        between 0 and 1.
        NOTE(review): the code subtracts `probs` from `labels` and multiplies
        by a [batch_size, num_instances, 1, 1] mask, so `probs` presumably has
        the same [batch_size, num_instances, mask_size, mask_size] layout as
        `labels` -- confirm against the caller.
      labels: A float32/float16 Tensor of shape [batch_size, num_instances,
        mask_size, mask_size], where mask_size =
        mask_crop_size * gt_upsample_scale for fine mask, or mask_crop_size
        for coarse masks and shape priors.
      valid_mask: a binary mask indicating valid training masks; its first two
        static dimensions are read as (batch_size, num_instances).

    Returns:
      loss: a float32 scalar tensor representing the total mask loss.
    """
    with tf.name_scope('shapemask_prior_loss'):
      batch_size, num_instances = valid_mask.get_shape().as_list()[:2]
      # Cast once so that the numerator and the normalizer share float32;
      # a float16 `labels` would otherwise give a float16 denominator that
      # cannot divide the float32 l2 loss (and can overflow when summed).
      labels = tf.cast(labels, dtype=tf.float32)
      diff = labels - tf.cast(probs, dtype=tf.float32)
      # Zeroes the difference of invalid instances.
      diff *= tf.cast(
          tf.reshape(valid_mask, [batch_size, num_instances, 1, 1]),
          tf.float32)
      # Adding 0.001 in the denominator to avoid division by zero.
      loss = tf.nn.l2_loss(diff) / (tf.reduce_sum(labels) + 0.001)
    return loss
class ShapemaskLoss(object):
  """ShapeMask mask loss function wrapper."""

  def __init__(self):
    """Builds the summed binary cross-entropy loss on logits.

    The attribute is kept for backward compatibility; `__call__` computes the
    element-wise cross-entropy directly so that the valid mask can be applied
    per instance.
    """
    self._binary_crossentropy = tf_keras.losses.BinaryCrossentropy(
        reduction=tf_keras.losses.Reduction.SUM, from_logits=True)

  def __call__(self, logits, labels, valid_mask):
    """ShapeMask mask cross entropy loss function wrapper.

    Args:
      logits: A Tensor of mask logits. The logits are not necessarily between
        0 and 1.
        NOTE(review): the element-wise loss is multiplied by a
        [batch_size, num_instances, 1, 1] mask, so `logits` presumably has the
        same [batch_size, num_instances, mask_size, mask_size] layout as
        `labels` -- confirm against the caller.
      labels: A float16/float32 Tensor of shape [batch_size, num_instances,
        mask_size, mask_size], where mask_size =
        mask_crop_size * gt_upsample_scale for fine mask, or mask_crop_size
        for coarse masks and shape priors.
      valid_mask: a binary mask of shape [batch_size, num_instances]
        indicating valid training masks.

    Returns:
      loss: a float tensor representing total mask classification loss.
    """
    with tf.name_scope('shapemask_loss'):
      batch_size, num_instances = valid_mask.get_shape().as_list()[:2]
      labels = tf.cast(labels, tf.float32)
      logits = tf.cast(logits, tf.float32)
      # The loss must stay element-wise here: a Keras loss built with
      # Reduction.SUM returns a scalar, so multiplying it by `valid_mask`
      # would only rescale the total by the number of valid instances instead
      # of removing the contribution of the invalid ones.
      loss = tf.nn.sigmoid_cross_entropy_with_logits(
          labels=labels, logits=logits)
      # Zeroes the loss of invalid instances.
      loss *= tf.cast(tf.reshape(
          valid_mask, [batch_size, num_instances, 1, 1]), loss.dtype)
      # Adding 0.001 in the denominator to avoid division by zero.
      loss = tf.reduce_sum(loss) / (tf.reduce_sum(labels) + 0.001)
    return loss
|