Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates | |
from typing import Dict, List, Tuple | |
import torch | |
from typing import List, Tuple, Union | |
import torch.nn.functional as F | |
from detectron2.config import configurable | |
from detectron2.utils.events import get_event_storage | |
from detectron2.layers import ShapeSpec, cat | |
from detectron2.structures import Boxes, Instances, pairwise_iou, pairwise_ioa | |
from detectron2.utils.memory import retry_if_cuda_oom | |
from fvcore.nn import smooth_l1_loss | |
from detectron2.layers import cat | |
from detectron2.layers import nonzero_tuple | |
from detectron2.modeling.box_regression import Box2BoxTransform, _dense_box_regression_loss | |
from detectron2.modeling.proposal_generator import RPN | |
from detectron2.modeling import PROPOSAL_GENERATOR_REGISTRY | |
class RPNWithIgnore(RPN): | |
def __init__( | |
self, | |
*, | |
ignore_thresh: float = 0.5, | |
objectness_uncertainty: str = 'none', | |
**kwargs | |
): | |
super().__init__(**kwargs) | |
self.ignore_thresh = ignore_thresh | |
self.objectness_uncertainty = objectness_uncertainty | |
def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): | |
ret = super().from_config(cfg, input_shape) | |
ret["ignore_thresh"] = cfg.MODEL.RPN.IGNORE_THRESHOLD | |
ret["objectness_uncertainty"] = cfg.MODEL.RPN.OBJECTNESS_UNCERTAINTY | |
return ret | |
def label_and_sample_anchors(self, anchors: List[Boxes], gt_instances: List[Instances]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: | |
anchors = Boxes.cat(anchors) | |
# separate valid and ignore gts | |
gt_boxes_ign = [x.gt_boxes[x.gt_classes < 0] for x in gt_instances] | |
gt_boxes = [x.gt_boxes[x.gt_classes >= 0] for x in gt_instances] | |
del gt_instances | |
gt_labels = [] | |
matched_gt_boxes = [] | |
for gt_boxes_i, gt_boxes_ign_i in zip(gt_boxes, gt_boxes_ign): | |
""" | |
gt_boxes_i: ground-truth boxes for i-th image | |
gt_boxes_ign_i: ground-truth ignore boxes for i-th image | |
""" | |
match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(gt_boxes_i, anchors) | |
matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix) | |
# Matching is memory-expensive and may result in CPU tensors. But the result is small | |
gt_labels_i = gt_labels_i.to(device=gt_boxes_i.device) | |
gt_arange = torch.arange(match_quality_matrix.shape[1]).to(matched_idxs.device) | |
matched_ious = match_quality_matrix[matched_idxs, gt_arange] | |
best_ious_gt_vals, best_ious_gt_ind = match_quality_matrix.max(dim=1) | |
del match_quality_matrix | |
best_inds = torch.tensor(list(set(best_ious_gt_ind.tolist()) & set((gt_labels_i == 1).nonzero().squeeze(1).tolist()))) | |
# A vector of labels (-1, 0, 1) for each anchor | |
# which denote (ignore, background, foreground) | |
gt_labels_i = self._subsample_labels(gt_labels_i, matched_ious=matched_ious) | |
# overrride the best possible GT options, always selected for sampling. | |
# otherwise aggressive thresholds may produce HUGE amounts of low quality FG. | |
if best_inds.numel() > 0: | |
gt_labels_i[best_inds] = 1.0 | |
if len(gt_boxes_i) == 0: | |
# These values won't be used anyway since the anchor is labeled as background | |
matched_gt_boxes_i = torch.zeros_like(anchors.tensor) | |
else: | |
# TODO wasted indexing computation for ignored boxes | |
matched_gt_boxes_i = gt_boxes_i[matched_idxs].tensor | |
if len(gt_boxes_ign_i) > 0: | |
# compute the quality matrix, only on subset of background | |
background_inds = (gt_labels_i == 0).nonzero().squeeze() | |
if background_inds.numel() > 1: | |
match_quality_matrix_ign = retry_if_cuda_oom(pairwise_ioa)(gt_boxes_ign_i, anchors[background_inds]) | |
# determine the boxes inside ignore regions with sufficient threshold | |
gt_labels_i[background_inds[match_quality_matrix_ign.max(0)[0] >= self.ignore_thresh]] = -1 | |
del match_quality_matrix_ign | |
gt_labels.append(gt_labels_i) # N,AHW | |
matched_gt_boxes.append(matched_gt_boxes_i) | |
return gt_labels, matched_gt_boxes | |
def _subsample_labels(self, label, matched_ious=None): | |
""" | |
Randomly sample a subset of positive and negative examples, and overwrite | |
the label vector to the ignore value (-1) for all elements that are not | |
included in the sample. | |
Args: | |
labels (Tensor): a vector of -1, 0, 1. Will be modified in-place and returned. | |
""" | |
pos_idx, neg_idx = subsample_labels( | |
label, self.batch_size_per_image, self.positive_fraction, 0, matched_ious=matched_ious | |
) | |
# Fill with the ignore label (-1), then set positive and negative labels | |
label.fill_(-1) | |
label.scatter_(0, pos_idx, 1) | |
label.scatter_(0, neg_idx, 0) | |
return label | |
def losses( | |
self, | |
anchors: List[Boxes], | |
pred_objectness_logits: List[torch.Tensor], | |
gt_labels: List[torch.Tensor], | |
pred_anchor_deltas: List[torch.Tensor], | |
gt_boxes: List[torch.Tensor], | |
) -> Dict[str, torch.Tensor]: | |
""" | |
Return the losses from a set of RPN predictions and their associated ground-truth. | |
Args: | |
anchors (list[Boxes or RotatedBoxes]): anchors for each feature map, each | |
has shape (Hi*Wi*A, B), where B is box dimension (4 or 5). | |
pred_objectness_logits (list[Tensor]): A list of L elements. | |
Element i is a tensor of shape (N, Hi*Wi*A) representing | |
the predicted objectness logits for all anchors. | |
gt_labels (list[Tensor]): Output of :meth:`label_and_sample_anchors`. | |
pred_anchor_deltas (list[Tensor]): A list of L elements. Element i is a tensor of shape | |
(N, Hi*Wi*A, 4 or 5) representing the predicted "deltas" used to transform anchors | |
to proposals. | |
gt_boxes (list[Tensor]): Output of :meth:`label_and_sample_anchors`. | |
Returns: | |
dict[loss name -> loss value]: A dict mapping from loss name to loss value. | |
Loss names are: `loss_rpn_cls` for objectness classification and | |
`loss_rpn_loc` for proposal localization. | |
""" | |
num_images = len(gt_labels) | |
gt_labels = torch.stack(gt_labels) # (N, sum(Hi*Wi*Ai)) | |
# Log the number of positive/negative anchors per-image that's used in training | |
pos_mask = gt_labels == 1 | |
num_pos_anchors = pos_mask.sum().item() | |
num_neg_anchors = (gt_labels == 0).sum().item() | |
storage = get_event_storage() | |
storage.put_scalar("rpn/num_pos_anchors", num_pos_anchors / num_images) | |
storage.put_scalar("rpn/num_neg_anchors", num_neg_anchors / num_images) | |
if not self.objectness_uncertainty.lower() in ['none']: | |
localization_loss, objectness_loss = _dense_box_regression_loss_with_uncertainty( | |
anchors, | |
self.box2box_transform, | |
pred_anchor_deltas, | |
pred_objectness_logits, | |
gt_boxes, | |
pos_mask, | |
box_reg_loss_type=self.box_reg_loss_type, | |
smooth_l1_beta=self.smooth_l1_beta, | |
uncertainty_type=self.objectness_uncertainty, | |
) | |
else: | |
localization_loss = _dense_box_regression_loss( | |
anchors, | |
self.box2box_transform, | |
pred_anchor_deltas, | |
gt_boxes, | |
pos_mask, | |
box_reg_loss_type=self.box_reg_loss_type, | |
smooth_l1_beta=self.smooth_l1_beta, | |
) | |
valid_mask = gt_labels >= 0 | |
objectness_loss = F.binary_cross_entropy_with_logits( | |
cat(pred_objectness_logits, dim=1)[valid_mask], | |
gt_labels[valid_mask].to(torch.float32), | |
reduction="sum", | |
) | |
normalizer = self.batch_size_per_image * num_images | |
losses = { | |
"rpn/cls": objectness_loss / normalizer, | |
"rpn/loc": localization_loss / normalizer, | |
} | |
losses = {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()} | |
return losses | |
def _dense_box_regression_loss_with_uncertainty( | |
anchors: List[Union[Boxes, torch.Tensor]], | |
box2box_transform: Box2BoxTransform, | |
pred_anchor_deltas: List[torch.Tensor], | |
pred_objectness_logits: List[torch.Tensor], | |
gt_boxes: List[torch.Tensor], | |
fg_mask: torch.Tensor, | |
box_reg_loss_type="smooth_l1", | |
smooth_l1_beta=0.0, | |
uncertainty_type='centerness', | |
): | |
""" | |
Compute loss for dense multi-level box regression. | |
Loss is accumulated over ``fg_mask``. | |
Args: | |
anchors: #lvl anchor boxes, each is (HixWixA, 4) | |
pred_anchor_deltas: #lvl predictions, each is (N, HixWixA, 4) | |
gt_boxes: N ground truth boxes, each has shape (R, 4) (R = sum(Hi * Wi * A)) | |
fg_mask: the foreground boolean mask of shape (N, R) to compute loss on | |
box_reg_loss_type (str): Loss type to use. Supported losses: "smooth_l1", "giou", | |
"diou", "ciou". | |
smooth_l1_beta (float): beta parameter for the smooth L1 regression loss. Default to | |
use L1 loss. Only used when `box_reg_loss_type` is "smooth_l1" | |
""" | |
if isinstance(anchors[0], Boxes): | |
anchors = type(anchors[0]).cat(anchors).tensor # (R, 4) | |
else: | |
anchors = cat(anchors) | |
n = len(gt_boxes) | |
boxes_fg = Boxes(anchors.unsqueeze(0).repeat([n, 1, 1])[fg_mask]) | |
gt_boxes_fg = Boxes(torch.stack(gt_boxes)[fg_mask].detach()) | |
objectness_targets_anchors = matched_pairwise_iou(boxes_fg, gt_boxes_fg).detach() | |
objectness_logits = torch.cat(pred_objectness_logits, dim=1) | |
# Numerically the same as (-(y*torch.log(p) + (1 - y)*torch.log(1 - p))).sum() | |
loss_box_conf = F.binary_cross_entropy_with_logits( | |
objectness_logits[fg_mask], | |
objectness_targets_anchors, | |
reduction='none' | |
) | |
loss_box_conf = (loss_box_conf * objectness_targets_anchors).sum() | |
# keep track of how scores look for FG / BG. | |
# ideally, FG slowly >>> BG scores as regression improves. | |
storage = get_event_storage() | |
storage.put_scalar("rpn/conf_pos_anchors", torch.sigmoid(objectness_logits[fg_mask]).mean().item()) | |
storage.put_scalar("rpn/conf_neg_anchors", torch.sigmoid(objectness_logits[~fg_mask]).mean().item()) | |
if box_reg_loss_type == "smooth_l1": | |
gt_anchor_deltas = [box2box_transform.get_deltas(anchors, k) for k in gt_boxes] | |
gt_anchor_deltas = torch.stack(gt_anchor_deltas) # (N, R, 4) | |
loss_box_reg = smooth_l1_loss( | |
cat(pred_anchor_deltas, dim=1)[fg_mask], | |
gt_anchor_deltas[fg_mask], | |
beta=smooth_l1_beta, | |
reduction="none", | |
) | |
loss_box_reg = (loss_box_reg.sum(dim=1) * objectness_targets_anchors).sum() | |
else: | |
raise ValueError(f"Invalid dense box regression loss type '{box_reg_loss_type}'") | |
return loss_box_reg, loss_box_conf | |
def subsample_labels( | |
labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int, matched_ious=None, eps=1e-4 | |
): | |
""" | |
Return `num_samples` (or fewer, if not enough found) | |
random samples from `labels` which is a mixture of positives & negatives. | |
It will try to return as many positives as possible without | |
exceeding `positive_fraction * num_samples`, and then try to | |
fill the remaining slots with negatives. | |
Args: | |
labels (Tensor): (N, ) label vector with values: | |
* -1: ignore | |
* bg_label: background ("negative") class | |
* otherwise: one or more foreground ("positive") classes | |
num_samples (int): The total number of labels with value >= 0 to return. | |
Values that are not sampled will be filled with -1 (ignore). | |
positive_fraction (float): The number of subsampled labels with values > 0 | |
is `min(num_positives, int(positive_fraction * num_samples))`. The number | |
of negatives sampled is `min(num_negatives, num_samples - num_positives_sampled)`. | |
In order words, if there are not enough positives, the sample is filled with | |
negatives. If there are also not enough negatives, then as many elements are | |
sampled as is possible. | |
bg_label (int): label index of background ("negative") class. | |
Returns: | |
pos_idx, neg_idx (Tensor): | |
1D vector of indices. The total length of both is `num_samples` or fewer. | |
""" | |
positive = nonzero_tuple((labels != -1) & (labels != bg_label))[0] | |
negative = nonzero_tuple(labels == bg_label)[0] | |
num_pos = int(num_samples * positive_fraction) | |
# protect against not enough positive examples | |
num_pos = min(positive.numel(), num_pos) | |
num_neg = num_samples - num_pos | |
# protect against not enough negative examples | |
num_neg = min(negative.numel(), num_neg) | |
#if positive_fraction == 1.0 and num_neg > 10: | |
# allow some negatives for statistics only. | |
#num_neg = 10 | |
# randomly select positive and negative examples | |
if num_pos > 0 and matched_ious is not None: | |
perm1 = torch.multinomial(matched_ious[positive] + eps, num_pos) | |
else: | |
perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos] | |
if num_neg > 0 and matched_ious is not None: | |
perm2 = torch.multinomial(matched_ious[negative] + eps, num_neg) | |
else: | |
perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg] | |
pos_idx = positive[perm1] | |
neg_idx = negative[perm2] | |
return pos_idx, neg_idx | |
def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor: | |
""" | |
Compute pairwise intersection over union (IOU) of two sets of matched | |
boxes that have the same number of boxes. | |
Similar to :func:`pairwise_iou`, but computes only diagonal elements of the matrix. | |
Args: | |
boxes1 (Boxes): bounding boxes, sized [N,4]. | |
boxes2 (Boxes): same length as boxes1 | |
Returns: | |
Tensor: iou, sized [N]. | |
""" | |
assert len(boxes1) == len( | |
boxes2 | |
), "boxlists should have the same" "number of entries, got {}, {}".format( | |
len(boxes1), len(boxes2) | |
) | |
area1 = boxes1.area() # [N] | |
area2 = boxes2.area() # [N] | |
box1, box2 = boxes1.tensor, boxes2.tensor | |
lt = torch.max(box1[:, :2], box2[:, :2]) # [N,2] | |
rb = torch.min(box1[:, 2:], box2[:, 2:]) # [N,2] | |
wh = (rb - lt).clamp(min=0) # [N,2] | |
inter = wh[:, 0] * wh[:, 1] # [N] | |
iou = inter / (area1 + area2 - inter) # [N] | |
return iou |