Spaces:
Runtime error
Runtime error
File size: 34,480 Bytes
a56642d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 |
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmcv.runner import force_fp32
from mmdet.core import (anchor_inside_flags, build_anchor_generator,
build_assigner, build_bbox_coder, build_sampler,
images_to_levels, multi_apply, multiclass_nms, unmap)
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin
@HEADS.register_module()
class AnchorHead(BaseDenseHead, BBoxTestMixin):
    """Anchor-based head (RPN, RetinaNet, SSD, etc.).

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
        anchor_generator (dict): Config dict for anchor generator
        bbox_coder (dict): Config of bounding box coder.
        reg_decoded_bbox (bool): If true, the regression loss would be
            applied directly on decoded bounding boxes, converting both
            the predicted boxes and regression targets to absolute
            coordinates format. Default False. It should be `True` when
            using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        train_cfg (dict): Training config of anchor head.
        test_cfg (dict): Testing config of anchor head.
    """  # noqa: W605

    # NOTE(review): the dict defaults below are shared between calls;
    # presumably the ``build_*`` helpers never mutate them -- confirm.
    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     scales=[8, 16, 32],
                     ratios=[0.5, 1.0, 2.0],
                     strides=[4, 8, 16, 32, 64]),
                 bbox_coder=dict(
                     type='DeltaXYWHBBoxCoder',
                     clip_border=True,
                     target_means=(.0, .0, .0, .0),
                     target_stds=(1.0, 1.0, 1.0, 1.0)),
                 reg_decoded_bbox=False,
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_bbox=dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
                 train_cfg=None,
                 test_cfg=None):
        super(AnchorHead, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.feat_channels = feat_channels
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        # TODO better way to determine whether sample or not
        self.sampling = loss_cls['type'] not in [
            'FocalLoss', 'GHMC', 'QualityFocalLoss'
        ]
        # Softmax classification needs one extra channel for background.
        if self.use_sigmoid_cls:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1

        if self.cls_out_channels <= 0:
            raise ValueError(f'num_classes={num_classes} is too small')
        self.reg_decoded_bbox = reg_decoded_bbox

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        # The assigner and sampler only exist when a training config is given.
        if self.train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            # use PseudoSampler when sampling is False
            if self.sampling and hasattr(self.train_cfg, 'sampler'):
                sampler_cfg = self.train_cfg.sampler
            else:
                sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)
        self.fp16_enabled = False

        self.anchor_generator = build_anchor_generator(anchor_generator)
        # usually the numbers of anchors for each level are the same
        # except SSD detectors
        self.num_anchors = self.anchor_generator.num_base_anchors[0]
        self._init_layers()

    def _init_layers(self):
        """Initialize layers of the head.

        Creates two 1x1 convolutions on ``in_channels`` inputs: ``conv_cls``
        with ``num_anchors * cls_out_channels`` outputs and ``conv_reg`` with
        ``num_anchors * 4`` outputs.
        """
        self.conv_cls = nn.Conv2d(self.in_channels,
                                  self.num_anchors * self.cls_out_channels, 1)
        self.conv_reg = nn.Conv2d(self.in_channels, self.num_anchors * 4, 1)

    def init_weights(self):
        """Initialize weights of the head."""
        normal_init(self.conv_cls, std=0.01)
        normal_init(self.conv_reg, std=0.01)

    def forward_single(self, x):
        """Forward feature of a single scale level.

        Args:
            x (Tensor): Features of a single scale level.

        Returns:
            tuple:
                cls_score (Tensor): Cls scores for a single scale level,
                    the channels number is num_anchors * cls_out_channels.
                bbox_pred (Tensor): Box energies / deltas for a single scale
                    level, the channels number is num_anchors * 4.
        """
        cls_score = self.conv_cls(x)
        bbox_pred = self.conv_reg(x)
        return cls_score, bbox_pred

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores and bbox prediction.

                - cls_scores (list[Tensor]): Classification scores for all
                    scale levels, each is a 4D-tensor, the channels number
                    is num_anchors * cls_out_channels.
                - bbox_preds (list[Tensor]): Box energies / deltas for all
                    scale levels, each is a 4D-tensor, the channels number
                    is num_anchors * 4.
        """
        return multi_apply(self.forward_single, feats)

    def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info. Each dict must provide
                the ``'pad_shape'`` key.
            device (torch.device | str): Device for returned tensors

        Returns:
            tuple:
                anchor_list (list[Tensor]): Anchors of each image.
                valid_flag_list (list[Tensor]): Valid flags of each image.
        """
        num_imgs = len(img_metas)

        # since feature map sizes of all images are the same, we only compute
        # anchors for one time
        multi_level_anchors = self.anchor_generator.grid_anchors(
            featmap_sizes, device)
        # Every image shares the very same list object of level anchors.
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]

        # for each image, we compute valid flags of multi level anchors
        valid_flag_list = [
            self.anchor_generator.valid_flags(featmap_sizes,
                                              img_meta['pad_shape'], device)
            for img_meta in img_metas
        ]
        return anchor_list, valid_flag_list

    def _get_targets_single(self,
                            flat_anchors,
                            valid_flags,
                            gt_bboxes,
                            gt_bboxes_ignore,
                            gt_labels,
                            img_meta,
                            label_channels=1,
                            unmap_outputs=True):
        """Compute regression and classification targets for anchors in a
        single image.

        Args:
            flat_anchors (Tensor): Multi-level anchors of the image, which are
                concatenated into a single tensor of shape (num_anchors ,4)
            valid_flags (Tensor): Multi level valid flags of the image,
                which are concatenated into a single tensor of
                    shape (num_anchors,).
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,). ``None`` marks every positive as class 0.
            img_meta (dict): Meta info of the image. Must provide the
                ``'img_shape'`` key.
            label_channels (int): Channel of label. Not used by this
                implementation; kept for interface compatibility.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: A 7-tuple for this single image. Every element is ``None``
            when no anchor lies inside the image.

                labels (Tensor): Label of each anchor; anchors that are not
                    positive keep the background label ``num_classes``.
                label_weights (Tensor): Label weight of each anchor.
                bbox_targets (Tensor): BBox target of each anchor.
                bbox_weights (Tensor): BBox weight of each anchor.
                pos_inds (Tensor): Indices of positive samples among the
                    anchors inside the image.
                neg_inds (Tensor): Indices of negative samples among the
                    anchors inside the image.
                sampling_result: The object returned by the sampler.
        """
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg.allowed_border)
        if not inside_flags.any():
            return (None, ) * 7
        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags, :]

        # Labels are only handed to the assigner when no sampling is done.
        assign_result = self.assigner.assign(
            anchors, gt_bboxes, gt_bboxes_ignore,
            None if self.sampling else gt_labels)
        sampling_result = self.sampler.sample(assign_result, anchors,
                                              gt_bboxes)

        num_valid_anchors = anchors.shape[0]
        bbox_targets = torch.zeros_like(anchors)
        bbox_weights = torch.zeros_like(anchors)
        # Start with every anchor labelled as background (``num_classes``).
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
            else:
                # Loss is computed on decoded boxes, so targets stay as the
                # raw ground-truth boxes.
                pos_bbox_targets = sampling_result.pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0
            if gt_labels is None:
                # Only rpn gives gt_labels as None
                # Foreground is the first class since v2.5.0
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            # A non-positive ``pos_weight`` means "use the default of 1.0".
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            labels = unmap(
                labels, num_total_anchors, inside_flags,
                fill=self.num_classes)  # fill bg label
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
            bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds, sampling_result)

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True,
                    return_sampling_results=False):
        """Compute regression and classification targets for anchors in
        multiple images.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, 4).
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, )
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.
            return_sampling_results (bool): Whether to append the per-image
                sampling results to the returned tuple.

        Returns:
            tuple | None: ``None`` if any image has no valid anchors.
            Otherwise usually a tuple containing learning targets.

                - labels_list (list[Tensor]): Labels of each level.
                - label_weights_list (list[Tensor]): Label weights of each
                    level.
                - bbox_targets_list (list[Tensor]): BBox targets of each level.
                - bbox_weights_list (list[Tensor]): BBox weights of each level.
                - num_total_pos (int): Number of positive samples in all
                    images.
                - num_total_neg (int): Number of negative samples in all
                    images.
            additional_returns: This function enables user-defined returns from
                `self._get_targets_single`. These returns are currently refined
                to properties at each feature map (i.e. having HxW dimension).
                The results will be concatenated after the end
        """
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
        # concat all level anchors to a single tensor
        concat_anchor_list = []
        concat_valid_flag_list = []
        for img_anchors, img_flags in zip(anchor_list, valid_flag_list):
            assert len(img_anchors) == len(img_flags)
            concat_anchor_list.append(torch.cat(img_anchors))
            concat_valid_flag_list.append(torch.cat(img_flags))

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        results = multi_apply(
            self._get_targets_single,
            concat_anchor_list,
            concat_valid_flag_list,
            gt_bboxes_list,
            gt_bboxes_ignore_list,
            gt_labels_list,
            img_metas,
            label_channels=label_channels,
            unmap_outputs=unmap_outputs)
        (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
         pos_inds_list, neg_inds_list, sampling_results_list) = results[:7]
        rest_results = list(results[7:])  # user-added return values
        # no valid anchors
        if any(labels is None for labels in all_labels):
            return None
        # sampled anchors of all images
        # Each image contributes at least 1 so the totals are never zero.
        num_total_pos = sum(max(inds.numel(), 1) for inds in pos_inds_list)
        num_total_neg = sum(max(inds.numel(), 1) for inds in neg_inds_list)
        # split targets to a list w.r.t. multiple levels
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        res = (labels_list, label_weights_list, bbox_targets_list,
               bbox_weights_list, num_total_pos, num_total_neg)
        if return_sampling_results:
            res = res + (sampling_results_list, )
        for i, r in enumerate(rest_results):  # user-added return values
            rest_results[i] = images_to_levels(r, num_level_anchors)

        return res + tuple(rest_results)

    def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples):
        """Compute loss of a single scale level.

        Args:
            cls_score (Tensor): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W).
            bbox_pred (Tensor): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            anchors (Tensor): Box reference for each scale level with shape
                (N, num_total_anchors, 4).
            labels (Tensor): Labels of each anchors with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors)
            bbox_targets (Tensor): BBox regression targets of each anchor with
                shape (N, num_total_anchors, 4).
            bbox_weights (Tensor): BBox regression loss weights of each anchor
                with shape (N, num_total_anchors, 4).
            num_total_samples (int): If sampling, num total samples equal to
                the number of total anchors; Otherwise, it is the number of
                positive anchors.

        Returns:
            tuple: ``(loss_cls, loss_bbox)``, the classification and the
            regression loss of this level.
        """
        # classification loss
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        # Move channels last so each row holds the scores of one anchor.
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=num_total_samples)
        # regression loss
        bbox_targets = bbox_targets.reshape(-1, 4)
        bbox_weights = bbox_weights.reshape(-1, 4)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        if self.reg_decoded_bbox:
            # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
            # is applied directly on the decoded bounding boxes, it
            # decodes the already encoded coordinates to absolute format.
            anchors = anchors.reshape(-1, 4)
            bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
        loss_bbox = self.loss_bbox(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            avg_factor=num_total_samples)
        return loss_cls, loss_bbox

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W)
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss. Default: None

        Returns:
            dict[str, Tensor] | None: A dictionary of loss components with
            the keys ``loss_cls`` and ``loss_bbox``, or ``None`` when
            ``get_targets`` yields no targets.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.anchor_generator.num_levels

        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (
            num_total_pos + num_total_neg if self.sampling else num_total_pos)

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
        # concat all level anchors and flags to a single tensor
        concat_anchor_list = [
            torch.cat(img_anchors) for img_anchors in anchor_list
        ]
        all_anchor_list = images_to_levels(concat_anchor_list,
                                           num_level_anchors)

        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            all_anchor_list,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples)
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg=None,
                   rescale=False,
                   with_nms=True):
        """Transform network output for a batch into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Box scores for each level in the
                feature pyramid, has shape
                (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each
                level in the feature pyramid, has shape
                (N, num_anchors * 4, H, W).
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            cfg (mmcv.Config | None): Test / postprocessing configuration,
                if None, test_cfg would be used
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            with_nms (bool): If True, do nms before return boxes.
                Default: True.

        Returns:
            list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
                The first item is an (n, 5) tensor, where 5 represent
                (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1.
                The shape of the second tensor in the tuple is (n,), and
                each element represents the class label of the corresponding
                box.

        Example:
            >>> import mmcv
            >>> self = AnchorHead(
            >>>     num_classes=9,
            >>>     in_channels=1,
            >>>     anchor_generator=dict(
            >>>         type='AnchorGenerator',
            >>>         scales=[8],
            >>>         ratios=[0.5, 1.0, 2.0],
            >>>         strides=[4,]))
            >>> img_metas = [{'img_shape': (32, 32, 3), 'scale_factor': 1}]
            >>> cfg = mmcv.Config(dict(
            >>>     score_thr=0.00,
            >>>     nms=dict(type='nms', iou_thr=1.0),
            >>>     max_per_img=10))
            >>> feat = torch.rand(1, 1, 3, 3)
            >>> cls_score, bbox_pred = self.forward_single(feat)
            >>> # note the input lists are over different levels, not images
            >>> cls_scores, bbox_preds = [cls_score], [bbox_pred]
            >>> result_list = self.get_bboxes(cls_scores, bbox_preds,
            >>>                               img_metas, cfg)
            >>> det_bboxes, det_labels = result_list[0]
            >>> assert len(result_list) == 1
            >>> assert det_bboxes.shape[1] == 5
            >>> assert len(det_bboxes) == len(det_labels) == cfg.max_per_img
        """
        assert len(cls_scores) == len(bbox_preds)
        num_levels = len(cls_scores)

        device = cls_scores[0].device
        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        mlvl_anchors = self.anchor_generator.grid_anchors(
            featmap_sizes, device=device)

        mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
        mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]

        # The batch size is read from the predictions, not from img_metas.
        num_imgs = cls_scores[0].shape[0]
        if torch.onnx.is_in_onnx_export():
            assert len(
                img_metas
            ) == 1, 'Only support one input image while in exporting to ONNX'
            img_shapes = img_metas[0]['img_shape_for_onnx']
        else:
            img_shapes = [img_metas[i]['img_shape'] for i in range(num_imgs)]
        scale_factors = [
            img_metas[i]['scale_factor'] for i in range(num_imgs)
        ]

        if with_nms:
            # some heads don't support with_nms argument
            result_list = self._get_bboxes(mlvl_cls_scores, mlvl_bbox_preds,
                                           mlvl_anchors, img_shapes,
                                           scale_factors, cfg, rescale)
        else:
            result_list = self._get_bboxes(mlvl_cls_scores, mlvl_bbox_preds,
                                           mlvl_anchors, img_shapes,
                                           scale_factors, cfg, rescale,
                                           with_nms)
        return result_list

    def _get_bboxes(self,
                    mlvl_cls_scores,
                    mlvl_bbox_preds,
                    mlvl_anchors,
                    img_shapes,
                    scale_factors,
                    cfg,
                    rescale=False,
                    with_nms=True):
        """Transform outputs for a batch item into bbox predictions.

        Args:
            mlvl_cls_scores (list[Tensor]): Each element in the list is
                the scores of bboxes of single level in the feature pyramid,
                has shape (N, num_anchors * num_classes, H, W).
            mlvl_bbox_preds (list[Tensor]):  Each element in the list is the
                bboxes predictions of single level in the feature pyramid,
                has shape (N, num_anchors * 4, H, W).
            mlvl_anchors (list[Tensor]): Each element in the list is
                the anchors of single level in feature pyramid, has shape
                (num_anchors, 4).
            img_shapes (list[tuple[int]]): Each tuple in the list represent
                the shape(height, width, 3) of single image in the batch.
            scale_factors (list[ndarray]): Scale factor of the batch
                image arranged as list[(w_scale, h_scale, w_scale, h_scale)].
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            with_nms (bool): If True, do nms before return boxes.
                Default: True.

        Returns:
            list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
                The first item is an (n, 5) tensor, where 5 represent
                (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1.
                The shape of the second tensor in the tuple is (n,), and
                each element represents the class label of the corresponding
                box. When ``with_nms`` is False, each item is instead the
                pair of un-suppressed (bboxes, scores) of one image.
        """
        cfg = self.test_cfg if cfg is None else cfg
        assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(
            mlvl_anchors)
        batch_size = mlvl_cls_scores[0].shape[0]
        # convert to tensor to keep tracing
        nms_pre_tensor = torch.tensor(
            cfg.get('nms_pre', -1),
            device=mlvl_cls_scores[0].device,
            dtype=torch.long)

        mlvl_bboxes = []
        mlvl_scores = []
        for cls_score, bbox_pred, anchors in zip(mlvl_cls_scores,
                                                 mlvl_bbox_preds,
                                                 mlvl_anchors):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            cls_score = cls_score.permute(0, 2, 3,
                                          1).reshape(batch_size, -1,
                                                     self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)
            bbox_pred = bbox_pred.permute(0, 2, 3,
                                          1).reshape(batch_size, -1, 4)
            anchors = anchors.expand_as(bbox_pred)
            # Always keep topk op for dynamic input in onnx
            if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export()
                                       or scores.shape[-2] > nms_pre_tensor):
                from torch import _shape_as_tensor

                # keep shape as tensor and get k
                num_anchor = _shape_as_tensor(scores)[-2].to(
                    nms_pre_tensor.device)
                nms_pre = torch.where(nms_pre_tensor < num_anchor,
                                      nms_pre_tensor, num_anchor)

                # Get maximum scores for foreground classes.
                if self.use_sigmoid_cls:
                    max_scores, _ = scores.max(-1)
                else:
                    # remind that we set FG labels to [0, num_class-1]
                    # since mmdet v2.0
                    # BG cat_id: num_class
                    max_scores, _ = scores[..., :-1].max(-1)

                _, topk_inds = max_scores.topk(nms_pre)
                # Build the batch index on the same device as topk_inds so
                # both index tensors live together.
                batch_inds = torch.arange(
                    batch_size, device=topk_inds.device).view(
                        -1, 1).expand_as(topk_inds)
                anchors = anchors[batch_inds, topk_inds, :]
                bbox_pred = bbox_pred[batch_inds, topk_inds, :]
                scores = scores[batch_inds, topk_inds, :]

            bboxes = self.bbox_coder.decode(
                anchors, bbox_pred, max_shape=img_shapes)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)

        batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
        if rescale:
            batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
                scale_factors).unsqueeze(1)
        batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)

        # Set max number of box to be feed into nms in deployment
        deploy_nms_pre = cfg.get('deploy_nms_pre', -1)
        if deploy_nms_pre > 0 and torch.onnx.is_in_onnx_export():
            # Get maximum scores for foreground classes.
            if self.use_sigmoid_cls:
                max_scores, _ = batch_mlvl_scores.max(-1)
            else:
                # remind that we set FG labels to [0, num_class-1]
                # since mmdet v2.0
                # BG cat_id: num_class
                max_scores, _ = batch_mlvl_scores[..., :-1].max(-1)
            _, topk_inds = max_scores.topk(deploy_nms_pre)
            batch_inds = torch.arange(
                batch_size, device=topk_inds.device).view(
                    -1, 1).expand_as(topk_inds)
            batch_mlvl_scores = batch_mlvl_scores[batch_inds, topk_inds]
            batch_mlvl_bboxes = batch_mlvl_bboxes[batch_inds, topk_inds]
        if self.use_sigmoid_cls:
            # Add a dummy background class to the backend when using sigmoid
            # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
            # BG cat_id: num_class
            padding = batch_mlvl_scores.new_zeros(batch_size,
                                                  batch_mlvl_scores.shape[1],
                                                  1)
            batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1)

        if with_nms:
            det_results = []
            # Distinct loop names avoid rebinding the per-level lists above.
            for img_bboxes, img_scores in zip(batch_mlvl_bboxes,
                                              batch_mlvl_scores):
                det_bbox, det_label = multiclass_nms(img_bboxes, img_scores,
                                                     cfg.score_thr, cfg.nms,
                                                     cfg.max_per_img)
                det_results.append(tuple([det_bbox, det_label]))
        else:
            det_results = [
                tuple(mlvl_bs)
                for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores)
            ]
        return det_results

    def aug_test(self, feats, img_metas, rescale=False):
        """Test function with test time augmentation.

        Args:
            feats (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains features for all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. each dict has image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[ndarray]: bbox results of each class
        """
        return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
|