Spaces:
Build error
Build error
File size: 5,752 Bytes
6250360 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
This file contains specific functions for computing losses on the RPN.
"""
import torch
from torch.nn import functional as F
from maskrcnn_benchmark.config import cfg
from .utils import concat_box_prediction_layers
from ..balanced_positive_negative_sampler import BalancedPositiveNegativeSampler
from ..utils import cat
from maskrcnn_benchmark.layers import smooth_l1_loss
from maskrcnn_benchmark.layers import iou_regress
from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
class RPNLossComputation(object):
    """
    This class computes the RPN loss.

    Calling an instance returns the objectness (classification) loss and the
    box loss for a batch of anchors, predictions and ground-truth targets.
    """

    def __init__(self, proposal_matcher, fg_bg_sampler, box_coder,
                 generate_labels_func):
        """
        Arguments:
            proposal_matcher (Matcher)
            fg_bg_sampler (BalancedPositiveNegativeSampler)
            box_coder (BoxCoder)
            generate_labels_func (callable): called with the matched targets
                of one image; its result is used as the per-anchor labels
        """
        self.proposal_matcher = proposal_matcher
        self.fg_bg_sampler = fg_bg_sampler
        # Stored for callers; not used by the loss computation in this class.
        self.box_coder = box_coder
        # Target fields to carry over onto the matched targets (none by default).
        self.copied_fields = []
        self.generate_labels_func = generate_labels_func
        # Cases whose anchors get label -1 in prepare_targets.
        self.discard_cases = ['not_visibility', 'between_thresholds']

    def match_targets_to_anchors(self, anchor, target, copied_fields=None):
        """
        Match every anchor of one image to a ground-truth box.

        Arguments:
            anchor (BoxList): anchors of a single image
            target (BoxList): ground-truth boxes of the same image
            copied_fields (list[str], optional): fields of ``target`` to keep
                on the result; ``None`` (the default) keeps no fields

        Returns:
            matched_targets: ``target`` indexed by the (clamped) match index of
            each anchor, with the raw match indices stored in the
            "matched_idxs" field
        """
        # Avoid a shared mutable default argument: build a fresh list per call.
        if copied_fields is None:
            copied_fields = []
        match_quality_matrix = boxlist_iou(target, anchor)
        matched_idxs = self.proposal_matcher(match_quality_matrix)
        # RPN doesn't need any fields from target for creating the labels, so clear them all
        target = target.copy_with_fields(copied_fields)
        # get the targets corresponding GT for each anchor
        # NB: need to clamp the indices because we can have a single
        # GT in the image, and matched_idxs can be -2, which goes
        # out of bounds
        matched_targets = target[matched_idxs.clamp(min=0)]
        matched_targets.add_field("matched_idxs", matched_idxs)
        return matched_targets

    def prepare_targets(self, anchors, targets):
        """
        Build per-image labels and regression targets for the anchors.

        Arguments:
            anchors (list[BoxList]): one BoxList of anchors per image
            targets (list[BoxList]): one BoxList of ground truth per image

        Returns:
            labels (list[Tensor]): float32 labels per image; background
                anchors are set to 0 and discarded anchors to -1
            regression_targets (list[Tensor]): the ``bbox`` of the matched
                targets per image (``box_coder`` is not applied here)
        """
        labels = []
        regression_targets = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            matched_targets = self.match_targets_to_anchors(
                anchors_per_image, targets_per_image, self.copied_fields
            )

            matched_idxs = matched_targets.get_field("matched_idxs")
            labels_per_image = self.generate_labels_func(matched_targets)
            labels_per_image = labels_per_image.to(dtype=torch.float32)

            # Background (negative examples)
            bg_indices = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
            labels_per_image[bg_indices] = 0

            # discard anchors that go out of the boundaries of the image
            if "not_visibility" in self.discard_cases:
                labels_per_image[~anchors_per_image.get_field("visibility")] = -1

            # discard indices that are between thresholds
            if "between_thresholds" in self.discard_cases:
                inds_to_discard = matched_idxs == Matcher.BETWEEN_THRESHOLDS
                labels_per_image[inds_to_discard] = -1

            # The matched ground-truth boxes are used directly as the targets.
            regression_targets_per_image = matched_targets.bbox

            labels.append(labels_per_image)
            regression_targets.append(regression_targets_per_image)

        return labels, regression_targets

    def __call__(self, anchors, objectness, box_regression, targets):
        """
        Arguments:
            anchors (list[BoxList])
            objectness (list[Tensor])
            box_regression (list[Tensor])
            targets (list[BoxList])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """
        anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
        labels, regression_targets = self.prepare_targets(anchors, targets)
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        # Turn the concatenated per-image masks into flat index tensors.
        sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
        sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness, box_regression = \
            concat_box_prediction_layers(objectness, box_regression)

        # Drop singleton dimensions so objectness can be indexed like labels.
        # NOTE(review): presumably objectness is (num_anchors, 1) here -- confirm.
        objectness = objectness.squeeze()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        # Box loss is computed on positive anchors only, but normalised by the
        # total number of sampled anchors (positives and negatives).
        # NOTE(review): the divisor is 0 when nothing is sampled -- confirm the
        # sampler always returns at least one index.
        box_loss = iou_regress(
            box_regression[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1.0 / 9,
            size_average=False,
        ) / (sampled_inds.numel())
        # Scale the box loss by the configured balancing weight.
        box_loss *= cfg.MODEL.ROI_BOUNDARY_HEAD.Loss_balance

        objectness_loss = F.binary_cross_entropy_with_logits(
            objectness[sampled_inds], labels[sampled_inds]
        )

        return objectness_loss, box_loss
# This function should be overwritten in RetinaNet
def generate_rpn_labels(matched_targets):
    """
    Label each anchor from its match index.

    Arguments:
        matched_targets: object exposing ``get_field("matched_idxs")``

    Returns:
        the result of ``matched_idxs >= 0``, i.e. true where the anchor was
        matched to a non-negative index
    """
    return matched_targets.get_field("matched_idxs") >= 0
def make_rpn_loss_evaluator(cfg, box_coder):
    """
    Build the RPN loss evaluator from the ``MODEL.RPN`` section of ``cfg``.

    Arguments:
        cfg: config node providing ``MODEL.RPN.FG_IOU_THRESHOLD``,
            ``BG_IOU_THRESHOLD``, ``BATCH_SIZE_PER_IMAGE`` and
            ``POSITIVE_FRACTION``
        box_coder (BoxCoder): passed through to the loss computation

    Returns:
        RPNLossComputation using ``generate_rpn_labels`` to create labels
    """
    rpn_cfg = cfg.MODEL.RPN

    anchor_matcher = Matcher(
        rpn_cfg.FG_IOU_THRESHOLD,
        rpn_cfg.BG_IOU_THRESHOLD,
        allow_low_quality_matches=True,
    )
    sampler = BalancedPositiveNegativeSampler(
        rpn_cfg.BATCH_SIZE_PER_IMAGE, rpn_cfg.POSITIVE_FRACTION
    )

    return RPNLossComputation(
        anchor_matcher, sampler, box_coder, generate_rpn_labels
    )
|