File size: 7,451 Bytes
e8f2571 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
# Copyright (c) OpenMMLab. All rights reserved.
# TODO: delete this file after refactor
import sys
import torch
from mmdet.models.layers import multiclass_nms
from mmdet.models.test_time_augs import merge_aug_bboxes, merge_aug_masks
from mmdet.structures.bbox import bbox2roi, bbox_mapping
if sys.version_info >= (3, 7):
from mmdet.utils.contextmanagers import completed
class BBoxTestMixin:
    """Legacy helpers for testing a box head.

    Offers an asynchronous single-scale path (only defined on Python 3.7+)
    and a test-time-augmentation path. Both are flagged as currently
    unsupported, and the file is slated for deletion after the refactor.
    """

    if sys.version_info >= (3, 7):
        # TODO: Currently not supported
        async def async_test_bboxes(self,
                                    x,
                                    img_metas,
                                    proposals,
                                    rcnn_test_cfg,
                                    rescale=False,
                                    **kwargs):
            """Asynchronized test for box head without augmentation.

            Args:
                x: Multi-level features; only the first
                    ``len(self.bbox_roi_extractor.featmap_strides)`` entries
                    are fed to the RoI extractor.
                img_metas: Per-image meta dicts; only ``img_metas[0]`` is
                    read (its ``'img_shape'`` and ``'scale_factor'``).
                proposals: Per-image proposal boxes, turned into RoIs with
                    ``bbox2roi``.
                rcnn_test_cfg: Test config. ``async_sleep_interval``
                    (default 0.017) is read from it, and it is forwarded to
                    ``get_bboxes`` as ``cfg``.
                rescale (bool): Forwarded to ``get_bboxes``.
                **kwargs: Accepted and ignored.

            Returns:
                tuple: ``(det_bboxes, det_labels)`` as produced by
                ``self.bbox_head.get_bboxes``.
            """
            rois = bbox2roi(proposals)
            extractor = self.bbox_roi_extractor
            level_feats = x[:len(extractor.featmap_strides)]
            roi_feats = extractor(level_feats, rois)
            if self.with_shared_head:
                roi_feats = self.shared_head(roi_feats)

            interval = rcnn_test_cfg.get('async_sleep_interval', 0.017)
            async with completed(
                    __name__, 'bbox_head_forward', sleep_interval=interval):
                cls_score, bbox_pred = self.bbox_head(roi_feats)

            first_meta = img_metas[0]
            img_shape = first_meta['img_shape']
            scale_factor = first_meta['scale_factor']
            det_bboxes, det_labels = self.bbox_head.get_bboxes(
                rois,
                cls_score,
                bbox_pred,
                img_shape,
                scale_factor,
                rescale=rescale,
                cfg=rcnn_test_cfg)
            return det_bboxes, det_labels

    # TODO: Currently not supported
    def aug_test_bboxes(self, feats, img_metas, rpn_results_list,
                        rcnn_test_cfg):
        """Test det bboxes with test time augmentation.

        Args:
            feats: One feature set per augmented view, zipped with
                ``img_metas``.
            img_metas: One list of meta dicts per view. Only the first dict
                of each list is read (``'img_shape'``, ``'scale_factor'``,
                ``'flip'``, ``'flip_direction'``).
            rpn_results_list: The first four columns of
                ``rpn_results_list[0]`` serve as proposals and are mapped
                into every view with ``bbox_mapping``.
            rcnn_test_cfg: Passed to ``merge_aug_bboxes``. Its
                ``score_thr``, ``nms`` and ``max_per_img`` drive
                ``multiclass_nms``.

        Returns:
            tuple: ``(det_bboxes, det_labels)``. If merging yields no boxes,
            these are an empty ``(0, 5)`` tensor and an empty ``torch.long``
            tensor.
        """
        aug_bboxes, aug_scores = [], []
        for feat, view_metas in zip(feats, img_metas):
            # only one image in the batch
            meta = view_metas[0]
            img_shape = meta['img_shape']
            scale_factor = meta['scale_factor']
            flip, flip_direction = meta['flip'], meta['flip_direction']
            # TODO more flexible
            proposals = bbox_mapping(rpn_results_list[0][:, :4], img_shape,
                                     scale_factor, flip, flip_direction)
            rois = bbox2roi([proposals])
            head_out = self.bbox_forward(feat, rois)
            bboxes, scores = self.bbox_head.get_bboxes(
                rois,
                head_out['cls_score'],
                head_out['bbox_pred'],
                img_shape,
                scale_factor,
                rescale=False,
                cfg=None)
            aug_bboxes.append(bboxes)
            aug_scores.append(scores)

        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
        if merged_bboxes.shape[0] == 0:
            # There is no proposal in the single image
            empty_bboxes = merged_bboxes.new_zeros(0, 5)
            empty_labels = merged_bboxes.new_zeros((0, ), dtype=torch.long)
            return empty_bboxes, empty_labels

        det_bboxes, det_labels = multiclass_nms(
            merged_bboxes, merged_scores, rcnn_test_cfg.score_thr,
            rcnn_test_cfg.nms, rcnn_test_cfg.max_per_img)
        return det_bboxes, det_labels
class MaskTestMixin:
    """Legacy helpers for testing a mask head.

    Offers an asynchronous single-scale path (only defined on Python 3.7+)
    and a test-time-augmentation path. Both are flagged as currently
    unsupported, and the file is slated for deletion after the refactor.
    """

    if sys.version_info >= (3, 7):
        # TODO: Currently not supported
        async def async_test_mask(self,
                                  x,
                                  img_metas,
                                  det_bboxes,
                                  det_labels,
                                  rescale=False,
                                  mask_test_cfg=None):
            """Asynchronized test for mask head without augmentation.

            Args:
                x: Multi-level features; only the first
                    ``len(self.mask_roi_extractor.featmap_strides)`` entries
                    are fed to the RoI extractor.
                img_metas: Per-image meta dicts; only ``img_metas[0]`` is
                    read (its ``'ori_shape'`` and ``'scale_factor'``).
                det_bboxes (Tensor): Detected boxes. Only the first four
                    columns are used as box coordinates.
                det_labels: Labels forwarded to ``get_results``.
                rescale (bool): If True, the box coordinates are multiplied
                    by ``scale_factor`` before RoI extraction (a non-float,
                    non-tensor ``scale_factor`` is converted to a tensor
                    first).
                mask_test_cfg (dict, optional): May carry
                    ``async_sleep_interval``. A missing or falsy value falls
                    back to 0.035.

            Returns:
                A list of ``self.mask_head.num_classes`` empty lists when
                ``det_bboxes`` has no rows, otherwise the output of
                ``self.mask_head.get_results``.
            """
            # image shape of the first image in the batch (only one)
            ori_shape = img_metas[0]['ori_shape']
            scale_factor = img_metas[0]['scale_factor']
            if det_bboxes.shape[0] == 0:
                return [[] for _ in range(self.mask_head.num_classes)]

            if rescale and not isinstance(scale_factor,
                                          (float, torch.Tensor)):
                scale_factor = det_bboxes.new_tensor(scale_factor)
            # Keep only the four box coordinates in BOTH branches. The
            # ``bbox2roi`` imported from ``mmdet.structures.bbox`` prepends
            # the batch index to every column it is given, so passing the raw
            # ``det_bboxes`` (which may carry a score column) would build
            # malformed RoIs. Re-verify if that helper changes.
            _bboxes = det_bboxes[:, :4]
            if rescale:
                _bboxes = _bboxes * scale_factor
            mask_rois = bbox2roi([_bboxes])
            mask_feats = self.mask_roi_extractor(
                x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
            if self.with_shared_head:
                mask_feats = self.shared_head(mask_feats)

            if mask_test_cfg and mask_test_cfg.get('async_sleep_interval'):
                sleep_interval = mask_test_cfg['async_sleep_interval']
            else:
                sleep_interval = 0.035
            async with completed(
                    __name__,
                    'mask_head_forward',
                    sleep_interval=sleep_interval):
                mask_pred = self.mask_head(mask_feats)
            return self.mask_head.get_results(mask_pred, _bboxes, det_labels,
                                              self.test_cfg, ori_shape,
                                              scale_factor, rescale)

    # TODO: Currently not supported
    def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
        """Test for mask head with test time augmentation.

        Args:
            feats: One feature set per augmented view, zipped with
                ``img_metas``.
            img_metas: One list of meta dicts per view. The first dict of
                each list supplies ``'img_shape'``, ``'scale_factor'``,
                ``'flip'`` and ``'flip_direction'``, and
                ``img_metas[0][0]['ori_shape']`` is passed to
                ``get_results``.
            det_bboxes (Tensor): Boxes whose first four columns are mapped
                into every view with ``bbox_mapping``.
            det_labels: Labels forwarded to ``get_results``.

        Returns:
            A list of ``self.mask_head.num_classes`` empty lists when
            ``det_bboxes`` has no rows, otherwise the output of
            ``self.mask_head.get_results`` on the merged masks.
        """
        if det_bboxes.shape[0] == 0:
            return [[] for _ in range(self.mask_head.num_classes)]

        aug_masks = []
        for x, img_meta in zip(feats, img_metas):
            meta = img_meta[0]
            _bboxes = bbox_mapping(det_bboxes[:, :4], meta['img_shape'],
                                   meta['scale_factor'], meta['flip'],
                                   meta['flip_direction'])
            mask_rois = bbox2roi([_bboxes])
            mask_results = self._mask_forward(x, mask_rois)
            # convert to numpy array to save memory
            aug_masks.append(
                mask_results['mask_pred'].sigmoid().cpu().numpy())
        merged_masks = merge_aug_masks(aug_masks, img_metas, self.test_cfg)

        ori_shape = img_metas[0][0]['ori_shape']
        # Identity scale factor: det_bboxes are handed to get_results as-is
        # together with rescale=False.
        scale_factor = det_bboxes.new_ones(4)
        return self.mask_head.get_results(
            merged_masks,
            det_bboxes,
            det_labels,
            self.test_cfg,
            ori_shape,
            scale_factor=scale_factor,
            rescale=False)
|