Spaces:
Build error
Build error
File size: 6,923 Bytes
6250360 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import torch
from ..inference import RPNPostProcessor
from ..utils import permute_and_flatten
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from maskrcnn_benchmark.modeling.utils import cat
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms
from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes
class RetinaNetPostProcessor(RPNPostProcessor):
    """
    Performs post-processing on the outputs of the RetinaNet boxes.
    This is only used in the testing.

    The work is split in two stages:

    * ``forward_for_single_feature_map`` thresholds the sigmoid class
      scores of one feature map, keeps the best-scoring candidates per
      image and decodes their boxes.
    * ``select_over_all_levels`` runs per-class NMS on the detections of
      each image and caps the number of detections per image.
    """

    def __init__(
        self,
        pre_nms_thresh,
        pre_nms_top_n,
        nms_thresh,
        fpn_post_nms_top_n,
        min_size,
        num_classes,
        box_coder=None,
    ):
        """
        Arguments:
            pre_nms_thresh (float): a (location, class) pair is a
                candidate only if its sigmoid score is strictly greater
                than this value
            pre_nms_top_n (int): upper bound on the number of candidates
                kept per image for a single feature map
            nms_thresh (float): threshold handed to ``boxlist_nms``
            fpn_post_nms_top_n (int): maximum number of detections kept
                per image over all classes; the cap is only applied when
                this value is greater than 0
            min_size (int): size handed to ``remove_small_boxes``
            num_classes (int): number of classes, label 0 being the
                background (labels ``1 .. num_classes - 1`` are processed)
            box_coder (BoxCoder): decoder for the regression outputs;
                defaults to ``BoxCoder(weights=(10., 10., 5., 5.))``
        """
        # NOTE(review): the parent signature is not visible from this file;
        # every attribute this class reads is (re)assigned explicitly below.
        super(RetinaNetPostProcessor, self).__init__(
            pre_nms_thresh, 0, nms_thresh, min_size
        )
        self.pre_nms_thresh = pre_nms_thresh
        self.pre_nms_top_n = pre_nms_top_n
        self.nms_thresh = nms_thresh
        self.fpn_post_nms_top_n = fpn_post_nms_top_n
        self.min_size = min_size
        self.num_classes = num_classes

        if box_coder is None:
            box_coder = BoxCoder(weights=(10., 10., 5., 5.))
        self.box_coder = box_coder

    def add_gt_proposals(self, proposals, targets):
        """
        This function is not used in RetinaNet
        """
        pass

    def forward_for_single_feature_map(
            self, anchors, box_cls, box_regression):
        """
        Arguments:
            anchors: list[BoxList]
            box_cls: tensor of size N, A * C, H, W
            box_regression: tensor of size N, A * 4, H, W

        Returns:
            list[BoxList]: one entry per image, holding the decoded boxes
            together with the fields "labels" and "scores".
        """
        N, _, H, W = box_cls.shape
        A = box_regression.size(1) // 4
        C = box_cls.size(1) // A

        # put in the same format as anchors
        box_cls = permute_and_flatten(box_cls, N, A, C, H, W)
        box_cls = box_cls.sigmoid()

        box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)
        box_regression = box_regression.reshape(N, -1, 4)

        # Boolean mask of the (location, class) pairs above the threshold,
        # and the per-image number of them to keep (capped).
        candidate_inds = box_cls > self.pre_nms_thresh
        pre_nms_top_n = candidate_inds.view(N, -1).sum(1)
        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)

        results = []
        per_image_inputs = zip(
            box_cls,
            box_regression,
            pre_nms_top_n,
            candidate_inds,
            anchors,
        )
        for per_box_cls, per_box_regression, per_pre_nms_top_n, \
                per_candidate_inds, per_anchors in per_image_inputs:
            # Sort and select TopN
            # TODO most of this can be made out of the loop for
            # all images.
            # TODO:Yang: Not easy to do. Because the numbers of detections are
            # different in each image. Therefore, this part needs to be done
            # per image.
            per_box_cls = per_box_cls[per_candidate_inds]
            per_box_cls, top_k_indices = \
                per_box_cls.topk(per_pre_nms_top_n, sorted=False)

            # Each row of the nonzero() result is (location index, class index)
            # of one surviving candidate.
            per_candidate_nonzeros = \
                per_candidate_inds.nonzero()[top_k_indices, :]
            per_box_loc = per_candidate_nonzeros[:, 0]
            # Shift class indices by one: label 0 is reserved for the
            # background, which select_over_all_levels skips.
            per_class = per_candidate_nonzeros[:, 1] + 1

            detections = self.box_coder.decode(
                per_box_regression[per_box_loc, :].view(-1, 4),
                per_anchors.bbox[per_box_loc, :].view(-1, 4)
            )

            boxlist = BoxList(detections, per_anchors.size, mode="xyxy")
            boxlist.add_field("labels", per_class)
            boxlist.add_field("scores", per_box_cls)
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = remove_small_boxes(boxlist, self.min_size)
            results.append(boxlist)

        return results

    # TODO very similar to filter_results from PostProcessor
    # but filter_results is per image
    # TODO Yang: solve this issue in the future. No good solution
    # right now.
    def select_over_all_levels(self, boxlists):
        """
        Apply per-class NMS and the per-image detection cap.

        Arguments:
            boxlists: list[BoxList], one per image, each carrying the
                fields "scores" and "labels"

        Returns:
            list[BoxList]: one entry per image with the fields "scores"
            and "labels" of the detections that were kept.
        """
        results = []
        for boxlist in boxlists:
            scores = boxlist.get_field("scores")
            labels = boxlist.get_field("labels")
            boxes = boxlist.bbox

            result = []
            # skip the background
            for j in range(1, self.num_classes):
                inds = (labels == j).nonzero().view(-1)

                scores_j = scores[inds]
                boxes_j = boxes[inds, :].view(-1, 4)
                boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
                boxlist_for_class.add_field("scores", scores_j)
                boxlist_for_class = boxlist_nms(
                    boxlist_for_class, self.nms_thresh,
                    score_field="scores"
                )
                num_labels = len(boxlist_for_class)
                boxlist_for_class.add_field(
                    "labels", torch.full((num_labels,), j,
                                         dtype=torch.int64,
                                         device=scores.device)
                )
                result.append(boxlist_for_class)

            result = cat_boxlist(result)
            number_of_detections = len(result)

            # Limit to max_per_image detections **over all classes**
            if number_of_detections > self.fpn_post_nms_top_n > 0:
                cls_scores = result.get_field("scores")
                # kthvalue returns the k-th smallest element, so this k
                # selects the fpn_post_nms_top_n-th largest score.
                image_thresh, _ = torch.kthvalue(
                    cls_scores.cpu(),
                    number_of_detections - self.fpn_post_nms_top_n + 1
                )
                # ">=" keeps every detection tied with the cut-off score.
                keep = cls_scores >= image_thresh.item()
                keep = torch.nonzero(keep).squeeze(1)
                result = result[keep]
            results.append(result)
        return results
def make_retinanet_postprocessor(config, rpn_box_coder, is_train):
    """
    Build a RetinaNetPostProcessor from a configuration object.

    Arguments:
        config: configuration node; the fields read are
            MODEL.RETINANET.INFERENCE_TH, MODEL.RETINANET.PRE_NMS_TOP_N,
            MODEL.RETINANET.NMS_TH, MODEL.RETINANET.NUM_CLASSES and
            TEST.DETECTIONS_PER_IMG
        rpn_box_coder: forwarded to the post-processor as ``box_coder``
        is_train: not read by this function

    Returns:
        RetinaNetPostProcessor: configured with ``min_size`` fixed to 0.
    """
    retinanet_cfg = config.MODEL.RETINANET
    return RetinaNetPostProcessor(
        pre_nms_thresh=retinanet_cfg.INFERENCE_TH,
        pre_nms_top_n=retinanet_cfg.PRE_NMS_TOP_N,
        nms_thresh=retinanet_cfg.NMS_TH,
        fpn_post_nms_top_n=config.TEST.DETECTIONS_PER_IMG,
        min_size=0,
        num_classes=retinanet_cfg.NUM_CLASSES,
        box_coder=rpn_box_coder,
    )
|