# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, List, Tuple, Union

import torch
import torch.nn.functional as F
from detectron2.config import configurable
from detectron2.utils.events import get_event_storage
from detectron2.layers import ShapeSpec, cat, nonzero_tuple
from detectron2.structures import Boxes, Instances, pairwise_iou, pairwise_ioa
from detectron2.utils.memory import retry_if_cuda_oom
from fvcore.nn import smooth_l1_loss
from detectron2.modeling.box_regression import Box2BoxTransform, _dense_box_regression_loss
from detectron2.modeling.proposal_generator import RPN
from detectron2.modeling import PROPOSAL_GENERATOR_REGISTRY
@PROPOSAL_GENERATOR_REGISTRY.register()
class RPNWithIgnore(RPN):
@configurable
def __init__(
self,
*,
ignore_thresh: float = 0.5,
objectness_uncertainty: str = 'none',
**kwargs
):
super().__init__(**kwargs)
self.ignore_thresh = ignore_thresh
self.objectness_uncertainty = objectness_uncertainty
@classmethod
def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
ret = super().from_config(cfg, input_shape)
ret["ignore_thresh"] = cfg.MODEL.RPN.IGNORE_THRESHOLD
ret["objectness_uncertainty"] = cfg.MODEL.RPN.OBJECTNESS_UNCERTAINTY
return ret
@torch.jit.unused
@torch.no_grad()
    def label_and_sample_anchors(
        self, anchors: List[Boxes], gt_instances: List[Instances]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Same interface as :meth:`RPN.label_and_sample_anchors`, but anchors whose
        intersection-over-anchor-area (IoA) with an "ignore" ground-truth box
        (``gt_classes < 0``) reaches ``self.ignore_thresh`` are labeled -1 and
        excluded from the loss.
        """
        anchors = Boxes.cat(anchors)
# separate valid and ignore gts
gt_boxes_ign = [x.gt_boxes[x.gt_classes < 0] for x in gt_instances]
gt_boxes = [x.gt_boxes[x.gt_classes >= 0] for x in gt_instances]
del gt_instances
gt_labels = []
matched_gt_boxes = []
for gt_boxes_i, gt_boxes_ign_i in zip(gt_boxes, gt_boxes_ign):
"""
gt_boxes_i: ground-truth boxes for i-th image
gt_boxes_ign_i: ground-truth ignore boxes for i-th image
"""
match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(gt_boxes_i, anchors)
matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix)
# Matching is memory-expensive and may result in CPU tensors. But the result is small
gt_labels_i = gt_labels_i.to(device=gt_boxes_i.device)
            # IoU between each anchor and the GT it was matched to
            anchor_arange = torch.arange(match_quality_matrix.shape[1]).to(matched_idxs.device)
            matched_ious = match_quality_matrix[matched_idxs, anchor_arange]
            _, best_ious_gt_ind = match_quality_matrix.max(dim=1)
            del match_quality_matrix
            # anchors that are both some GT's best match and labeled foreground
            best_inds = torch.tensor(
                list(set(best_ious_gt_ind.tolist()) & set((gt_labels_i == 1).nonzero().squeeze(1).tolist())),
                dtype=torch.long,
                device=gt_labels_i.device,
            )
# A vector of labels (-1, 0, 1) for each anchor
# which denote (ignore, background, foreground)
gt_labels_i = self._subsample_labels(gt_labels_i, matched_ious=matched_ious)
            # Override the best-possible GT matches: they are always selected as
            # foreground, otherwise aggressive thresholds may produce huge amounts
            # of low-quality foreground.
            if best_inds.numel() > 0:
                gt_labels_i[best_inds] = 1
if len(gt_boxes_i) == 0:
# These values won't be used anyway since the anchor is labeled as background
matched_gt_boxes_i = torch.zeros_like(anchors.tensor)
else:
# TODO wasted indexing computation for ignored boxes
matched_gt_boxes_i = gt_boxes_i[matched_idxs].tensor
if len(gt_boxes_ign_i) > 0:
                # compute the quality matrix only on the subset of background anchors
                background_inds = (gt_labels_i == 0).nonzero().squeeze(1)
                if background_inds.numel() > 0:
match_quality_matrix_ign = retry_if_cuda_oom(pairwise_ioa)(gt_boxes_ign_i, anchors[background_inds])
# determine the boxes inside ignore regions with sufficient threshold
gt_labels_i[background_inds[match_quality_matrix_ign.max(0)[0] >= self.ignore_thresh]] = -1
del match_quality_matrix_ign
gt_labels.append(gt_labels_i) # N,AHW
matched_gt_boxes.append(matched_gt_boxes_i)
return gt_labels, matched_gt_boxes
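
    # Worked illustration (hypothetical numbers, not from the repo): with
    # ignore_thresh = 0.5, an anchor (0, 0, 10, 10) and an ignore box
    # (0, 0, 10, 5) overlap in a 10x5 strip, so IoA = 50 / 100 = 0.5 and the
    # anchor flips from background (0) to ignore (-1); with an ignore box
    # (0, 0, 10, 3) the IoA is only 0.3 and the anchor stays a negative.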
    def _subsample_labels(self, label, matched_ious=None):
        """
        Randomly sample a subset of positive and negative examples, and overwrite
        the label vector to the ignore value (-1) for all elements that are not
        included in the sample.
        Args:
            label (Tensor): a vector of -1, 0, 1. Will be modified in-place and returned.
            matched_ious (Tensor, optional): per-anchor IoU with its matched GT;
                when provided, sampling is biased toward higher-IoU anchors.
        """
pos_idx, neg_idx = subsample_labels(
label, self.batch_size_per_image, self.positive_fraction, 0, matched_ious=matched_ious
)
# Fill with the ignore label (-1), then set positive and negative labels
label.fill_(-1)
label.scatter_(0, pos_idx, 1)
label.scatter_(0, neg_idx, 0)
return label
@torch.jit.unused
def losses(
self,
anchors: List[Boxes],
pred_objectness_logits: List[torch.Tensor],
gt_labels: List[torch.Tensor],
pred_anchor_deltas: List[torch.Tensor],
gt_boxes: List[torch.Tensor],
) -> Dict[str, torch.Tensor]:
"""
Return the losses from a set of RPN predictions and their associated ground-truth.
Args:
anchors (list[Boxes or RotatedBoxes]): anchors for each feature map, each
has shape (Hi*Wi*A, B), where B is box dimension (4 or 5).
pred_objectness_logits (list[Tensor]): A list of L elements.
Element i is a tensor of shape (N, Hi*Wi*A) representing
the predicted objectness logits for all anchors.
gt_labels (list[Tensor]): Output of :meth:`label_and_sample_anchors`.
pred_anchor_deltas (list[Tensor]): A list of L elements. Element i is a tensor of shape
(N, Hi*Wi*A, 4 or 5) representing the predicted "deltas" used to transform anchors
to proposals.
gt_boxes (list[Tensor]): Output of :meth:`label_and_sample_anchors`.
Returns:
dict[loss name -> loss value]: A dict mapping from loss name to loss value.
Loss names are: `loss_rpn_cls` for objectness classification and
`loss_rpn_loc` for proposal localization.
"""
num_images = len(gt_labels)
gt_labels = torch.stack(gt_labels) # (N, sum(Hi*Wi*Ai))
# Log the number of positive/negative anchors per-image that's used in training
pos_mask = gt_labels == 1
num_pos_anchors = pos_mask.sum().item()
num_neg_anchors = (gt_labels == 0).sum().item()
storage = get_event_storage()
storage.put_scalar("rpn/num_pos_anchors", num_pos_anchors / num_images)
storage.put_scalar("rpn/num_neg_anchors", num_neg_anchors / num_images)
        if self.objectness_uncertainty.lower() != 'none':
localization_loss, objectness_loss = _dense_box_regression_loss_with_uncertainty(
anchors,
self.box2box_transform,
pred_anchor_deltas,
pred_objectness_logits,
gt_boxes,
pos_mask,
box_reg_loss_type=self.box_reg_loss_type,
smooth_l1_beta=self.smooth_l1_beta,
uncertainty_type=self.objectness_uncertainty,
)
else:
localization_loss = _dense_box_regression_loss(
anchors,
self.box2box_transform,
pred_anchor_deltas,
gt_boxes,
pos_mask,
box_reg_loss_type=self.box_reg_loss_type,
smooth_l1_beta=self.smooth_l1_beta,
)
valid_mask = gt_labels >= 0
objectness_loss = F.binary_cross_entropy_with_logits(
cat(pred_objectness_logits, dim=1)[valid_mask],
gt_labels[valid_mask].to(torch.float32),
reduction="sum",
)
normalizer = self.batch_size_per_image * num_images
        losses = {
            "rpn/cls": objectness_loss / normalizer,
            "rpn/loc": localization_loss / normalizer,
        }
        # NOTE: the base RPN's self.loss_weight is keyed by "loss_rpn_cls" /
        # "loss_rpn_loc", so for the renamed keys above the .get() falls back to 1.0.
        losses = {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()}
return losses
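
# Usage sketch (illustrative, not part of the original file): the class registered
# above is selected through detectron2's standard proposal-generator hook. The two
# extra keys read in from_config must already exist on the CfgNode (e.g. added by
# the project's config setup); the values below are example defaults.
def _demo_rpn_with_ignore_cfg(cfg):
    cfg.MODEL.PROPOSAL_GENERATOR.NAME = "RPNWithIgnore"
    cfg.MODEL.RPN.IGNORE_THRESHOLD = 0.5
    cfg.MODEL.RPN.OBJECTNESS_UNCERTAINTY = "none"
    return cfg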
def _dense_box_regression_loss_with_uncertainty(
anchors: List[Union[Boxes, torch.Tensor]],
box2box_transform: Box2BoxTransform,
pred_anchor_deltas: List[torch.Tensor],
pred_objectness_logits: List[torch.Tensor],
gt_boxes: List[torch.Tensor],
fg_mask: torch.Tensor,
box_reg_loss_type="smooth_l1",
smooth_l1_beta=0.0,
uncertainty_type='centerness',
):
"""
Compute loss for dense multi-level box regression.
Loss is accumulated over ``fg_mask``.
Args:
anchors: #lvl anchor boxes, each is (HixWixA, 4)
pred_anchor_deltas: #lvl predictions, each is (N, HixWixA, 4)
gt_boxes: N ground truth boxes, each has shape (R, 4) (R = sum(Hi * Wi * A))
fg_mask: the foreground boolean mask of shape (N, R) to compute loss on
box_reg_loss_type (str): Loss type to use. Supported losses: "smooth_l1", "giou",
"diou", "ciou".
smooth_l1_beta (float): beta parameter for the smooth L1 regression loss. Default to
use L1 loss. Only used when `box_reg_loss_type` is "smooth_l1"
"""
if isinstance(anchors[0], Boxes):
anchors = type(anchors[0]).cat(anchors).tensor # (R, 4)
else:
anchors = cat(anchors)
    n = len(gt_boxes)
    # the IoU between each foreground anchor and its matched GT serves as a soft
    # objectness target (rather than a hard label of 1)
    boxes_fg = Boxes(anchors.unsqueeze(0).repeat([n, 1, 1])[fg_mask])
    gt_boxes_fg = Boxes(torch.stack(gt_boxes)[fg_mask].detach())
    objectness_targets_anchors = matched_pairwise_iou(boxes_fg, gt_boxes_fg).detach()
objectness_logits = torch.cat(pred_objectness_logits, dim=1)
    # elementwise BCE against the soft IoU targets; numerically the same as
    # -(y*torch.log(p) + (1 - y)*torch.log(1 - p)) per anchor
loss_box_conf = F.binary_cross_entropy_with_logits(
objectness_logits[fg_mask],
objectness_targets_anchors,
reduction='none'
)
    # re-weight each anchor's BCE term by its IoU target before summing
    loss_box_conf = (loss_box_conf * objectness_targets_anchors).sum()
    # keep track of how scores look for FG / BG;
    # ideally, FG scores pull well above BG scores as regression improves.
storage = get_event_storage()
storage.put_scalar("rpn/conf_pos_anchors", torch.sigmoid(objectness_logits[fg_mask]).mean().item())
storage.put_scalar("rpn/conf_neg_anchors", torch.sigmoid(objectness_logits[~fg_mask]).mean().item())
if box_reg_loss_type == "smooth_l1":
gt_anchor_deltas = [box2box_transform.get_deltas(anchors, k) for k in gt_boxes]
gt_anchor_deltas = torch.stack(gt_anchor_deltas) # (N, R, 4)
loss_box_reg = smooth_l1_loss(
cat(pred_anchor_deltas, dim=1)[fg_mask],
gt_anchor_deltas[fg_mask],
beta=smooth_l1_beta,
reduction="none",
)
loss_box_reg = (loss_box_reg.sum(dim=1) * objectness_targets_anchors).sum()
else:
raise ValueError(f"Invalid dense box regression loss type '{box_reg_loss_type}'")
return loss_box_reg, loss_box_conf
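
def _demo_iou_weighted_objectness():
    # Minimal sketch (hypothetical, not from the repo) of the target construction
    # used above: foreground objectness is regressed toward the anchor/GT IoU
    # instead of a hard 1, and each BCE term is re-weighted by that same IoU so
    # that well-localized anchors dominate the confidence loss.
    logits = torch.tensor([2.0, -1.0, 0.5])      # predicted objectness logits
    iou_targets = torch.tensor([0.9, 0.3, 0.6])  # IoU of each anchor with its GT
    per_anchor = F.binary_cross_entropy_with_logits(logits, iou_targets, reduction='none')
    return (per_anchor * iou_targets).sum()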
def subsample_labels(
labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int, matched_ious=None, eps=1e-4
):
"""
Return `num_samples` (or fewer, if not enough found)
random samples from `labels` which is a mixture of positives & negatives.
It will try to return as many positives as possible without
exceeding `positive_fraction * num_samples`, and then try to
fill the remaining slots with negatives.
Args:
labels (Tensor): (N, ) label vector with values:
* -1: ignore
* bg_label: background ("negative") class
* otherwise: one or more foreground ("positive") classes
num_samples (int): The total number of labels with value >= 0 to return.
Values that are not sampled will be filled with -1 (ignore).
positive_fraction (float): The number of subsampled labels with values > 0
is `min(num_positives, int(positive_fraction * num_samples))`. The number
of negatives sampled is `min(num_negatives, num_samples - num_positives_sampled)`.
            In other words, if there are not enough positives, the sample is filled
            with negatives. If there are also not enough negatives, then as many
            elements are sampled as is possible.
        bg_label (int): label index of background ("negative") class.
        matched_ious (Tensor, optional): per-element IoU with the matched GT; when
            provided, indices are drawn with ``torch.multinomial`` with probability
            proportional to IoU + ``eps`` instead of uniformly.
        eps (float): small constant that keeps zero-IoU elements sampleable.
Returns:
pos_idx, neg_idx (Tensor):
1D vector of indices. The total length of both is `num_samples` or fewer.
"""
positive = nonzero_tuple((labels != -1) & (labels != bg_label))[0]
negative = nonzero_tuple(labels == bg_label)[0]
num_pos = int(num_samples * positive_fraction)
# protect against not enough positive examples
num_pos = min(positive.numel(), num_pos)
num_neg = num_samples - num_pos
# protect against not enough negative examples
num_neg = min(negative.numel(), num_neg)
# randomly select positive and negative examples
if num_pos > 0 and matched_ious is not None:
perm1 = torch.multinomial(matched_ious[positive] + eps, num_pos)
else:
perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
if num_neg > 0 and matched_ious is not None:
perm2 = torch.multinomial(matched_ious[negative] + eps, num_neg)
else:
perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
pos_idx = positive[perm1]
neg_idx = negative[perm2]
return pos_idx, neg_idx
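
def _demo_subsample_labels():
    # Smoke test (hypothetical, not part of the repo): draw 2 positives and 2
    # negatives from a toy label vector, biasing the draw by per-anchor IoU the
    # way subsample_labels does when matched_ious is passed.
    labels = torch.tensor([1, 0, 0, -1, 1, 0, 0, 1])
    ious = torch.tensor([0.90, 0.10, 0.20, 0.00, 0.70, 0.05, 0.30, 0.55])
    pos_idx, neg_idx = subsample_labels(
        labels, num_samples=4, positive_fraction=0.5, bg_label=0, matched_ious=ious
    )
    # high-IoU anchors (e.g. index 0) are proportionally more likely to be kept
    return pos_idx, neg_idx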
def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
"""
Compute pairwise intersection over union (IOU) of two sets of matched
boxes that have the same number of boxes.
Similar to :func:`pairwise_iou`, but computes only diagonal elements of the matrix.
Args:
boxes1 (Boxes): bounding boxes, sized [N,4].
boxes2 (Boxes): same length as boxes1
Returns:
Tensor: iou, sized [N].
"""
    assert len(boxes1) == len(
        boxes2
    ), "boxlists should have the same number of entries, got {}, {}".format(
        len(boxes1), len(boxes2)
    )
area1 = boxes1.area() # [N]
area2 = boxes2.area() # [N]
box1, box2 = boxes1.tensor, boxes2.tensor
lt = torch.max(box1[:, :2], box2[:, :2]) # [N,2]
rb = torch.min(box1[:, 2:], box2[:, 2:]) # [N,2]
wh = (rb - lt).clamp(min=0) # [N,2]
inter = wh[:, 0] * wh[:, 1] # [N]
iou = inter / (area1 + area2 - inter) # [N]
return iou
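
def _demo_matched_pairwise_iou():
    # Tiny check (illustrative, not from the repo): identical boxes give IoU 1;
    # two 2x2 boxes overlapping in a 1x2 strip give IoU 2 / (4 + 4 - 2) = 1/3.
    b1 = Boxes(torch.tensor([[0.0, 0.0, 2.0, 2.0], [0.0, 0.0, 2.0, 2.0]]))
    b2 = Boxes(torch.tensor([[0.0, 0.0, 2.0, 2.0], [1.0, 0.0, 3.0, 2.0]]))
    return matched_pairwise_iou(b1, b2)  # approx. tensor([1.0000, 0.3333])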