# Extraction residue (not part of the module): "Spaces: Runtime error / Runtime error"
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch | |
from maskrcnn_benchmark.modeling.box_coder import BoxCoder | |
from maskrcnn_benchmark.structures.bounding_box import BoxList | |
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist | |
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms | |
from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes | |
from ..utils import cat | |
from .utils import permute_and_flatten | |
class RPNPostProcessor(torch.nn.Module):
    """
    Performs post-processing on the outputs of the RPN boxes, before feeding the
    proposals to the heads
    """

    def __init__(
        self,
        pre_nms_top_n,
        post_nms_top_n,
        nms_thresh,
        min_size,
        box_coder=None,
        fpn_post_nms_top_n=None,
    ):
        """
        Arguments:
            pre_nms_top_n (int): upper bound on the number of top-scoring
                candidates kept per image and feature map before NMS
            post_nms_top_n (int): forwarded to boxlist_nms as max_proposals
            nms_thresh (float): forwarded to boxlist_nms
            min_size (int): forwarded to remove_small_boxes
            box_coder (BoxCoder): decoder used to turn regression outputs into
                boxes; when None, BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) is used
            fpn_post_nms_top_n (int): cap applied in select_over_all_levels;
                when None, post_nms_top_n is used
        """
        super(RPNPostProcessor, self).__init__()
        self.pre_nms_top_n = pre_nms_top_n  # e.g. 12000
        self.post_nms_top_n = post_nms_top_n  # e.g. 2000
        self.nms_thresh = nms_thresh  # e.g. 0.7
        self.min_size = min_size  # e.g. 0

        if box_coder is None:
            box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
        self.box_coder = box_coder

        if fpn_post_nms_top_n is None:
            fpn_post_nms_top_n = post_nms_top_n
        self.fpn_post_nms_top_n = fpn_post_nms_top_n  # e.g. 2000

    def add_gt_proposals(self, proposals, targets):
        """
        Concatenate the ground-truth boxes of each image onto its proposals.

        Arguments:
            proposals: list[BoxList]
            targets: list[BoxList]

        Returns:
            list[BoxList]: one entry per (proposal, target) pair
        """
        # Get the device we're operating on
        device = proposals[0].bbox.device

        gt_boxes = [target.copy_with_fields([]) for target in targets]

        # later cat of bbox requires all fields to be present for all bbox
        # so we need to add a dummy for objectness that's missing
        for gt_box in gt_boxes:
            gt_box.add_field("objectness", torch.ones(len(gt_box), device=device))

        proposals = [
            cat_boxlist((proposal, gt_box))
            for proposal, gt_box in zip(proposals, gt_boxes)
        ]

        return proposals

    def forward_for_single_feature_map(self, anchors, objectness, box_regression):
        """
        Decode, clip, filter and NMS the proposals of one feature map.

        Arguments:
            anchors: list[BoxList], one entry per image
            objectness: tensor of size N, A, H, W
            box_regression: regression output for the same feature map. It is
                reshaped with 18 values per anchor below.
                NOTE(review): the previous docstring described this as
                N, A * 4, H, W, which disagrees with the hard-coded 18 --
                confirm against the head and BoxCoder.decode_iou.

        Returns:
            list[BoxList]: one entry per image, carrying an "objectness" field
        """
        device = objectness.device
        N, A, H, W = objectness.shape

        # Values per anchor expected from the regression head (see NOTE above).
        num_reg_values = 18

        # put in the same format as anchors
        objectness = permute_and_flatten(objectness, N, A, 1, H, W).view(N, -1)
        objectness = objectness.sigmoid()

        box_regression = permute_and_flatten(
            box_regression, N, A, num_reg_values, H, W
        )

        num_anchors = A * H * W  # e.g. 391040, 97760

        # topk cannot ask for more entries than exist
        pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
        objectness, topk_idx = objectness.topk(pre_nms_top_n, dim=1, sorted=True)

        # Column vector of image indices; broadcasts against topk_idx so each
        # image gathers its own top-scoring rows.
        batch_idx = torch.arange(N, device=device)[:, None]
        box_regression = box_regression[batch_idx, topk_idx]

        image_shapes = [box.size for box in anchors]
        concat_anchors = torch.cat([a.bbox for a in anchors], dim=0)
        # The reshape requires every image to contribute the same anchor count.
        concat_anchors = concat_anchors.reshape(N, -1, 4)[batch_idx, topk_idx]

        proposals = self.box_coder.decode_iou(
            box_regression.view(-1, num_reg_values), concat_anchors.view(-1, 4)
        )

        proposals = proposals.view(N, -1, 4)

        result = []
        for proposal, score, im_shape in zip(proposals, objectness, image_shapes):
            boxlist = BoxList(proposal, im_shape, mode="xyxy")
            boxlist.add_field("objectness", score)
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = remove_small_boxes(boxlist, self.min_size)
            boxlist = boxlist_nms(
                boxlist,
                self.nms_thresh,
                max_proposals=self.post_nms_top_n,
                score_field="objectness",
            )
            result.append(boxlist)
        return result

    def forward(self, anchors, objectness, box_regression, targets=None):
        """
        Arguments:
            anchors: list[list[BoxList]]
            objectness: list[tensor]
            box_regression: list[tensor]
            targets: optional list[BoxList]; when given and the module is in
                training mode, they are appended to the proposals

        Returns:
            boxlists (list[BoxList]): the post-processed anchors, after
                applying box decoding and NMS
        """
        sampled_boxes = []
        num_levels = len(objectness)
        # Transpose the nesting so the outer index lines up with the
        # objectness / box_regression lists it is zipped with below.
        anchors = list(zip(*anchors))
        for a, o, b in zip(anchors, objectness, box_regression):
            sampled_boxes.append(self.forward_for_single_feature_map(a, o, b))

        # Transpose back and merge the per-feature-map results of each image.
        boxlists = list(zip(*sampled_boxes))
        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]

        if num_levels > 1:
            boxlists = self.select_over_all_levels(boxlists)

        # append ground-truth bboxes to proposals
        if self.training and targets is not None:
            boxlists = self.add_gt_proposals(boxlists, targets)

        return boxlists

    def select_over_all_levels(self, boxlists):
        """
        Keep at most fpn_post_nms_top_n highest-objectness proposals.

        Arguments:
            boxlists: list[BoxList], one per image; modified in place

        Returns:
            list[BoxList]: the same list object with each entry filtered
        """
        num_images = len(boxlists)
        # different behavior during training and during testing:
        # during training, post_nms_top_n is over *all* the proposals combined, while
        # during testing, it is over the proposals for each image
        # TODO resolve this difference and make it consistent. It should be per image,
        # and not per batch
        if self.training:
            objectness = torch.cat(
                [boxlist.get_field("objectness") for boxlist in boxlists], dim=0
            )
            box_sizes = [len(boxlist) for boxlist in boxlists]
            post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness))
            _, inds_sorted = torch.topk(objectness, post_nms_top_n, dim=0, sorted=True)
            # Boolean mask: indexing with a uint8 mask is deprecated in PyTorch.
            inds_mask = torch.zeros_like(objectness, dtype=torch.bool)
            inds_mask[inds_sorted] = True
            # Cut the batch-wide mask back into one piece per image.
            inds_mask = inds_mask.split(box_sizes)
            for i, keep in enumerate(inds_mask):
                boxlists[i] = boxlists[i][keep]
        else:
            for i in range(num_images):
                objectness = boxlists[i].get_field("objectness")
                post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness))
                _, inds_sorted = torch.topk(
                    objectness, post_nms_top_n, dim=0, sorted=True
                )
                boxlists[i] = boxlists[i][inds_sorted]
        return boxlists
def make_rpn_postprocessor(config, rpn_box_coder, is_train):
    """
    Build an RPNPostProcessor from the values under config.MODEL.RPN.

    Arguments:
        config: configuration object exposing MODEL.RPN.* attributes
        rpn_box_coder: passed through as the post-processor's box_coder
        is_train (bool): use the *_TRAIN limits when true, otherwise the
            *_TEST limits

    Returns:
        RPNPostProcessor
    """
    rpn_cfg = config.MODEL.RPN

    # Start from the training limits ...
    fpn_limit = rpn_cfg.FPN_POST_NMS_TOP_N_TRAIN  # e.g. 2000
    pre_limit = rpn_cfg.PRE_NMS_TOP_N_TRAIN  # e.g. 12000
    post_limit = rpn_cfg.POST_NMS_TOP_N_TRAIN  # e.g. 2000

    # ... and swap in the test-time ones when not training.
    if not is_train:
        fpn_limit = rpn_cfg.FPN_POST_NMS_TOP_N_TEST
        pre_limit = rpn_cfg.PRE_NMS_TOP_N_TEST
        post_limit = rpn_cfg.POST_NMS_TOP_N_TEST

    return RPNPostProcessor(
        pre_nms_top_n=pre_limit,
        post_nms_top_n=post_limit,
        nms_thresh=rpn_cfg.NMS_THRESH,  # e.g. 0.7
        min_size=rpn_cfg.MIN_SIZE,  # e.g. 0
        box_coder=rpn_box_coder,
        fpn_post_nms_top_n=fpn_limit,
    )