# Cyril666 — "First model version"
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms
from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes
from ..utils import cat
from .utils import permute_and_flatten
class RPNPostProcessor(torch.nn.Module):
    """
    Performs post-processing on the outputs of the RPN boxes, before feeding the
    proposals to the heads.

    For every feature level, the highest-scoring anchors are decoded into
    proposals, clipped to the image, size-filtered and reduced with NMS. When
    several levels are present, the merged per-image results are re-ranked by
    ``select_over_all_levels``. In training mode, ground-truth boxes can be
    appended to the proposals.
    """

    # Number of regression values predicted per anchor. Both the flattening of
    # ``box_regression`` and the input to ``box_coder.decode_iou`` use it.
    _BOX_REG_DIM = 18

    def __init__(
        self,
        pre_nms_top_n,
        post_nms_top_n,
        nms_thresh,
        min_size,
        box_coder=None,
        fpn_post_nms_top_n=None,
    ):
        """
        Arguments:
            pre_nms_top_n (int): number of highest-scoring anchors kept per image
                and per feature level before decoding.
            post_nms_top_n (int): ``max_proposals`` passed to ``boxlist_nms``
                for each image and feature level.
            nms_thresh (float): IoU threshold passed to ``boxlist_nms``.
            min_size (int): forwarded to ``remove_small_boxes``.
            box_coder (BoxCoder): decoder for the regression outputs; defaults
                to ``BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))``.
            fpn_post_nms_top_n (int): number of proposals kept when merging all
                feature levels; defaults to ``post_nms_top_n``.
        """
        super().__init__()
        self.pre_nms_top_n = pre_nms_top_n
        self.post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.min_size = min_size

        if box_coder is None:
            box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
        self.box_coder = box_coder

        if fpn_post_nms_top_n is None:
            fpn_post_nms_top_n = post_nms_top_n
        self.fpn_post_nms_top_n = fpn_post_nms_top_n

    def add_gt_proposals(self, proposals, targets):
        """
        Append each image's ground-truth boxes to its proposals.

        Arguments:
            proposals: list[BoxList], one per image
            targets: list[BoxList], one per image

        Returns:
            list[BoxList]: for each image, the proposals concatenated with the
            ground-truth boxes (which carry an all-ones "objectness" field).
        """
        # Get the device we're operating on
        device = proposals[0].bbox.device

        gt_boxes = [target.copy_with_fields([]) for target in targets]

        # A later concatenation of boxes requires all fields to be present for
        # every BoxList, so add a dummy objectness that the ground truth lacks.
        for gt_box in gt_boxes:
            gt_box.add_field("objectness", torch.ones(len(gt_box), device=device))

        return [
            cat_boxlist((proposal, gt_box))
            for proposal, gt_box in zip(proposals, gt_boxes)
        ]

    def forward_for_single_feature_map(self, anchors, objectness, box_regression):
        """
        Decode and filter the proposals of a single feature level.

        Arguments:
            anchors: list[BoxList], one per image, holding this level's anchors
            objectness: tensor of size N, A, H, W
            box_regression: tensor of size N, A * 18, H, W

        Returns:
            list[BoxList]: one per image, with an "objectness" field holding the
            sigmoid scores of the proposals that survived filtering and NMS.
        """
        device = objectness.device
        N, A, H, W = objectness.shape
        reg_dim = self._BOX_REG_DIM

        # put in the same format as anchors
        objectness = permute_and_flatten(objectness, N, A, 1, H, W).view(N, -1)
        objectness = objectness.sigmoid()
        box_regression = permute_and_flatten(box_regression, N, A, reg_dim, H, W)

        # Keep only the highest-scoring anchors of each image before decoding.
        num_anchors = A * H * W
        pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
        objectness, topk_idx = objectness.topk(pre_nms_top_n, dim=1, sorted=True)

        # (N, 1) column so that [batch_idx, topk_idx] gathers per-image rows.
        batch_idx = torch.arange(N, device=device)[:, None]
        box_regression = box_regression[batch_idx, topk_idx]

        image_shapes = [box.size for box in anchors]
        concat_anchors = torch.cat([a.bbox for a in anchors], dim=0)
        concat_anchors = concat_anchors.reshape(N, -1, 4)[batch_idx, topk_idx]

        # NOTE(review): decode_iou is assumed to return one 4-value box per
        # anchor; the result is viewed as (N, -1, 4) and wrapped as "xyxy".
        proposals = self.box_coder.decode_iou(
            box_regression.view(-1, reg_dim), concat_anchors.view(-1, 4)
        )
        proposals = proposals.view(N, -1, 4)

        result = []
        for proposal, score, im_shape in zip(proposals, objectness, image_shapes):
            boxlist = BoxList(proposal, im_shape, mode="xyxy")
            boxlist.add_field("objectness", score)
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = remove_small_boxes(boxlist, self.min_size)
            boxlist = boxlist_nms(
                boxlist,
                self.nms_thresh,
                max_proposals=self.post_nms_top_n,
                score_field="objectness",
            )
            result.append(boxlist)
        return result

    def forward(self, anchors, objectness, box_regression, targets=None):
        """
        Arguments:
            anchors: list[list[BoxList]], indexed [image][feature level]
            objectness: list[tensor], one per feature level
            box_regression: list[tensor], one per feature level
            targets: list[BoxList], optional ground truth; only appended to
                the proposals in training mode

        Returns:
            boxlists (list[BoxList]): the post-processed anchors, after
                applying box decoding and NMS
        """
        num_levels = len(objectness)
        # Transpose anchors from [image][level] to [level][image] so they line
        # up with the per-level objectness / box_regression lists.
        anchors_per_level = list(zip(*anchors))
        sampled_boxes = [
            self.forward_for_single_feature_map(a, o, b)
            for a, o, b in zip(anchors_per_level, objectness, box_regression)
        ]

        # Transpose back to [image][level] and merge the levels of each image.
        boxlists = [cat_boxlist(per_image) for per_image in zip(*sampled_boxes)]

        if num_levels > 1:
            boxlists = self.select_over_all_levels(boxlists)

        # append ground-truth bboxes to proposals
        if self.training and targets is not None:
            boxlists = self.add_gt_proposals(boxlists, targets)

        return boxlists

    def select_over_all_levels(self, boxlists):
        """
        Keep the ``fpn_post_nms_top_n`` highest-objectness proposals.

        Different behavior during training and during testing: during training,
        the top-n is taken over *all* the proposals of the batch combined, while
        during testing it is taken over the proposals of each image.
        TODO resolve this difference and make it consistent. It should be per
        image, and not per batch.

        Arguments:
            boxlists: list[BoxList], one per image, each with an "objectness"
                field; the list is updated in place.

        Returns:
            list[BoxList]: the same list, with each entry filtered.
        """
        if self.training:
            objectness = torch.cat(
                [boxlist.get_field("objectness") for boxlist in boxlists], dim=0
            )
            box_sizes = [len(boxlist) for boxlist in boxlists]
            post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness))
            _, inds_sorted = torch.topk(objectness, post_nms_top_n, dim=0, sorted=True)
            # Boolean mask: indexing with uint8 masks is deprecated in PyTorch.
            inds_mask = torch.zeros_like(objectness, dtype=torch.bool)
            inds_mask[inds_sorted] = True
            # Split the batch-wide mask back into one mask per image.
            for i, mask in enumerate(inds_mask.split(box_sizes)):
                boxlists[i] = boxlists[i][mask]
        else:
            for i, boxlist in enumerate(boxlists):
                objectness = boxlist.get_field("objectness")
                post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness))
                _, inds_sorted = torch.topk(
                    objectness, post_nms_top_n, dim=0, sorted=True
                )
                boxlists[i] = boxlist[inds_sorted]
        return boxlists
def make_rpn_postprocessor(config, rpn_box_coder, is_train):
    """
    Build an RPNPostProcessor from the ``config.MODEL.RPN`` settings.

    Arguments:
        config: config object providing the ``MODEL.RPN`` options read below
        rpn_box_coder (BoxCoder): box coder used to decode the RPN regressions
        is_train (bool): selects the ``*_TRAIN`` (True) or ``*_TEST`` (False)
            top-n settings

    Returns:
        RPNPostProcessor
    """
    rpn_cfg = config.MODEL.RPN
    if is_train:
        fpn_post_nms_top_n = rpn_cfg.FPN_POST_NMS_TOP_N_TRAIN
        pre_nms_top_n = rpn_cfg.PRE_NMS_TOP_N_TRAIN
        post_nms_top_n = rpn_cfg.POST_NMS_TOP_N_TRAIN
    else:
        fpn_post_nms_top_n = rpn_cfg.FPN_POST_NMS_TOP_N_TEST
        pre_nms_top_n = rpn_cfg.PRE_NMS_TOP_N_TEST
        post_nms_top_n = rpn_cfg.POST_NMS_TOP_N_TEST

    box_selector = RPNPostProcessor(
        pre_nms_top_n=pre_nms_top_n,
        post_nms_top_n=post_nms_top_n,
        nms_thresh=rpn_cfg.NMS_THRESH,
        min_size=rpn_cfg.MIN_SIZE,
        box_coder=rpn_box_coder,
        fpn_post_nms_top_n=fpn_post_nms_top_n,
    )
    return box_selector