Spaces:
Runtime error
Runtime error
File size: 3,931 Bytes
3094730 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdeploy.core import mark
from torch import Tensor
def _efficient_nms(
    boxes: Tensor,
    scores: Tensor,
    max_output_boxes_per_class: int = 1000,
    iou_threshold: float = 0.5,
    score_threshold: float = 0.05,
    pre_top_k: int = -1,
    keep_top_k: int = 100,
    box_coding: int = 0,
):
    """Wrapper for `efficient_nms` with TensorRT.

    Args:
        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
            Inputs that are not already 4-D get an extra axis inserted at
            position 2 before being handed to the op.
        scores (Tensor): The detection scores of shape
            [N, num_boxes, num_classes].
        max_output_boxes_per_class (int): Maximum number of output
            boxes per class of nms. Defaults to 1000.
            Accepted for signature compatibility only; this wrapper does
            not forward it to the op.
        iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
        score_threshold (float): score threshold of nms.
            Defaults to 0.05.
        pre_top_k (int): Number of top K boxes to keep before nms.
            Defaults to -1.
            Accepted for signature compatibility only; this wrapper does
            not forward it to the op.
        keep_top_k (int): Number of top K boxes to keep after nms.
            Passed to the op as `max_output_boxes`. Defaults to 100.
        box_coding (int): Bounding boxes format for nms, following the
            TensorRT EfficientNMS plugin convention.
            Defaults to 0 means corners [x1, y1, x2, y2].
            Set to 1 means center-size [x, y, w, h].

    Returns:
        tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
            and `labels` of shape [N, num_det].
    """
    # The op is always called with 4-D boxes; 3-D input gains a singleton
    # axis between the box axis and the coordinate axis.
    boxes = boxes if boxes.dim() == 4 else boxes.unsqueeze(2)

    # Positional arguments follow TRTEfficientNMSop.forward:
    # background_class=-1, box_coding, iou_threshold,
    # max_output_boxes=keep_top_k, plugin_version='1', score_activation=0,
    # score_threshold. The first output (number of detections) is unused.
    _, det_boxes, det_scores, labels = TRTEfficientNMSop.apply(
        boxes, scores, -1, box_coding, iou_threshold, keep_top_k, '1', 0,
        score_threshold)

    # Append the score as a fifth column after the four box values.
    dets = torch.cat([det_boxes, det_scores.unsqueeze(2)], -1)

    # retain shape info
    batch_size = boxes.size(0)
    dets_shape = dets.shape
    label_shape = labels.shape
    dets = dets.reshape([batch_size, *dets_shape[1:]])
    labels = labels.reshape([batch_size, *label_shape[1:]])
    return dets, labels
@mark('efficient_nms', inputs=['boxes', 'scores'], outputs=['dets', 'labels'])
def efficient_nms(*args, **kwargs):
    """Marked entry point that forwards all arguments to `_efficient_nms`.

    The function is registered through `mark` under the name
    'efficient_nms' with inputs ('boxes', 'scores') and outputs
    ('dets', 'labels').

    Returns:
        The value returned by `_efficient_nms` for the same arguments.
    """
    result = _efficient_nms(*args, **kwargs)
    return result
class TRTEfficientNMSop(torch.autograd.Function):
    """Efficient NMS op for TensorRT.

    `forward` performs no actual NMS: it returns randomly filled placeholder
    tensors with the op's output shapes and dtypes. `symbolic` emits the
    `TRT::EfficientNMS_TRT` node that represents the op in an exported graph.
    """

    @staticmethod
    def forward(
        ctx,
        boxes,
        scores,
        background_class=-1,
        box_coding=0,
        iou_threshold=0.45,
        max_output_boxes=100,
        plugin_version='1',
        score_activation=0,
        score_threshold=0.25,
    ):
        """Forward function of TRTEfficientNMSop.

        Args:
            ctx: Autograd context (unused).
            boxes (Tensor): Input boxes. Not read here; only forwarded to
                the graph node by `symbolic`.
            scores (Tensor): Scores of shape
                [batch_size, num_boxes, num_classes]. Only its shape and
                device are used.
            background_class (int): Unused here; see `symbolic`.
            box_coding (int): Unused here; see `symbolic`.
            iou_threshold (float): Unused here; see `symbolic`.
            max_output_boxes (int): Size of the detection axis of the
                placeholder outputs.
            plugin_version (str): Unused here; see `symbolic`.
            score_activation (int): Unused here; see `symbolic`.
            score_threshold (float): Unused here; see `symbolic`.

        Returns:
            tuple[Tensor, Tensor, Tensor, Tensor]: Placeholder
            `(num_det, det_boxes, det_scores, det_classes)` with shapes
            [batch_size, 1] (int32), [batch_size, max_output_boxes, 4],
            [batch_size, max_output_boxes] and
            [batch_size, max_output_boxes] (int32).
        """
        batch_size, _num_boxes, num_classes = scores.shape
        # Create the placeholders on the same device as the inputs so that
        # downstream ops never mix devices.
        device = scores.device
        num_det = torch.randint(
            0,
            max_output_boxes, (batch_size, 1),
            dtype=torch.int32,
            device=device)
        det_boxes = torch.randn(batch_size, max_output_boxes, 4, device=device)
        det_scores = torch.randn(batch_size, max_output_boxes, device=device)
        det_classes = torch.randint(
            0,
            num_classes, (batch_size, max_output_boxes),
            dtype=torch.int32,
            device=device)
        return num_det, det_boxes, det_scores, det_classes

    @staticmethod
    def symbolic(g,
                 boxes,
                 scores,
                 background_class=-1,
                 box_coding=0,
                 iou_threshold=0.45,
                 max_output_boxes=100,
                 plugin_version='1',
                 score_activation=0,
                 score_threshold=0.25):
        """Symbolic function of TRTEfficientNMSop.

        Emits one `TRT::EfficientNMS_TRT` node with `boxes` and `scores` as
        inputs and the remaining arguments as node attributes. The keyword
        suffixes (`_i`, `_f`, `_s`) select int, float and string attribute
        types respectively.

        Returns:
            tuple: The node's four outputs, in the order
            `(nums, boxes, scores, classes)`.
        """
        out = g.op(
            'TRT::EfficientNMS_TRT',
            boxes,
            scores,
            background_class_i=background_class,
            box_coding_i=box_coding,
            iou_threshold_f=iou_threshold,
            max_output_boxes_i=max_output_boxes,
            plugin_version_s=plugin_version,
            score_activation_i=score_activation,
            score_threshold_f=score_threshold,
            outputs=4)
        nums, boxes, scores, classes = out
        return nums, boxes, scores, classes
|