Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
from torch import Tensor | |
from torchvision.ops import batched_nms | |
_XYWH2XYXY = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], | |
[-0.5, 0.0, 0.5, 0.0], [0.0, -0.5, 0.0, 0.5]], | |
dtype=torch.float32) | |
def sort_nms_index(nms_index, scores, batch_size, keep_top_k=-1):
    """Group NMS results by image, sort each image by score, apply keep_top_k.

    Rows of ``nms_index`` are reordered so that all detections belonging to
    the same image are contiguous (ascending batch index) and, inside each
    image, ordered by descending score. Ties keep their input order. When
    ``keep_top_k > 0`` only the first ``keep_top_k`` rows of every image are
    retained.

    Args:
        nms_index (Tensor): Integer tensor of shape (num_det, 3); each row is
            (batch index, class index, box index).
        scores (Tensor): Score tensor indexed as ``scores[batch, class, box]``.
        batch_size (int): Number of images; ``num_dets`` has at least this
            many entries.
        keep_top_k (int): Maximum detections kept per image. Values <= 0
            disable the limit.

    Returns:
        tuple[Tensor, Tensor]: The reordered (and possibly truncated)
        ``nms_index``, in the caller's dtype, and ``num_dets``, the number of
        rows kept for each image.
    """
    # Index with an int64 copy (the most portable index dtype) but return
    # rows in the caller's original dtype.
    index = nms_index.long()
    det_scores = scores[index[:, 0], index[:, 1], index[:, 2]]

    # Two stable sorts produce the lexicographic order (batch asc, score desc):
    # sort by score first, then by batch index without disturbing that order.
    order = torch.sort(det_scores, descending=True, stable=True).indices
    order = order[torch.sort(index[order, 0], stable=True).indices]
    nms_index = nms_index[order]
    batch_inds = index[order, 0]

    num_dets = torch.bincount(batch_inds, minlength=batch_size)
    if keep_top_k > 0:
        # Rank of each row inside its image's block = global row number minus
        # the row at which that image's block starts. Computing this on the
        # already-sorted rows avoids the stale-score misalignment that
        # in-loop deletion would cause.
        block_start = torch.cumsum(num_dets, dim=0) - num_dets
        rank = (torch.arange(batch_inds.numel(), device=batch_inds.device)
                - block_start[batch_inds])
        nms_index = nms_index[rank < keep_top_k]
        num_dets = num_dets.clamp(max=keep_top_k)
    return nms_index, num_dets
def select_nms_index(
    scores: Tensor,
    boxes: Tensor,
    nms_index: Tensor,
    batch_size: int,
    keep_top_k: int = -1,
):
    """Gather the detections selected by NMS, grouped per image.

    Args:
        scores (Tensor): Scores indexed as ``scores[batch, class, box]``.
        boxes (Tensor): Boxes indexed as ``boxes[batch, box, :]``.
        nms_index (Tensor): (num_det, 3) rows of (batch, class, box) indices
            produced by NMS; may be empty.
        batch_size (int): Number of images in the batch.
        keep_top_k (int): Maximum detections kept per image (<= 0: no limit).

    Returns:
        tuple: ``(num_dets, batched_dets, batched_scores, batched_labels)``.
        ``num_dets`` holds one detection count per image. The other three
        are flat tensors with one entry per kept detection, grouped by image
        and sorted by descending score within each image.
    """
    if nms_index.numel() == 0:
        # Mirror the non-empty layout: one (zero) count per image, on the
        # input device, so callers need no special case for "no detections".
        device = scores.device
        return (
            torch.zeros(batch_size, dtype=torch.int64, device=device),
            boxes.new_zeros((0, 4)),
            scores.new_zeros((0,)),
            torch.zeros(0, dtype=torch.int64, device=device),
        )
    nms_index, num_dets = sort_nms_index(nms_index, scores, batch_size,
                                         keep_top_k)
    # int64 indices are accepted for advanced indexing on every torch version.
    index = nms_index.long()
    batch_inds, cls_inds, box_inds = index[:, 0], index[:, 1], index[:, 2]
    batched_scores = scores[batch_inds, cls_inds, box_inds]
    batched_dets = boxes[batch_inds, box_inds, ...]
    # Labels keep the dtype of the incoming indices, as before.
    batched_labels = nms_index[:, 1]
    return num_dets, batched_dets, batched_scores, batched_labels
def construct_indice(batch_idx, select_bbox_idxs, class_idxs, original_idxs):
    """Build (batch, class, box) index rows for the boxes kept in one image.

    Args:
        batch_idx (int): Batch position of the image, repeated on every row.
        select_bbox_idxs (Tensor): Positions of the kept boxes within the
            filtered per-image tensors.
        class_idxs (Tensor): Class index for every filtered box.
        original_idxs (Tensor): Map from filtered position to the box index in
            the unfiltered input.

    Returns:
        Tensor: int32 tensor of shape (len(select_bbox_idxs), 3) placed on
        ``select_bbox_idxs.device``.
    """
    target = select_bbox_idxs.device
    rows = len(select_bbox_idxs)
    batch_column = torch.full((rows,), batch_idx, dtype=torch.int32,
                              device=target)
    class_column = class_idxs[select_bbox_idxs].to(device=target,
                                                   dtype=torch.int32)
    box_column = original_idxs[select_bbox_idxs].to(device=target,
                                                    dtype=torch.int32)
    return torch.stack((batch_column, class_column, box_column), dim=1)
def filter_max_boxes_per_class(
    select_bbox_idxs, class_idxs, max_output_boxes_per_class
):
    """Keep at most ``max_output_boxes_per_class`` entries of each class.

    Entries are visited in input order, so the first occurrences of each class
    survive; with score-sorted input these are the highest-scoring boxes.

    Args:
        select_bbox_idxs (Tensor): 1-D indices of candidate boxes.
        class_idxs (Tensor): 1-D class index per candidate, aligned with
            ``select_bbox_idxs``.
        max_output_boxes_per_class (int | Tensor): Per-class cap; a
            one-element tensor is accepted.

    Returns:
        tuple[Tensor, Tensor]: Surviving box indices and their class indices.
        Input dtypes and devices are preserved, including when nothing
        survives. The previous list-to-tensor round trip produced a CPU
        float tensor in that case, which cannot be used as an index.
    """
    select_bbox_idxs = torch.as_tensor(select_bbox_idxs)
    class_idxs = torch.as_tensor(class_idxs)
    # Like zip(), only consider the overlapping prefix of the two inputs.
    num = min(select_bbox_idxs.numel(), class_idxs.numel())
    select_bbox_idxs, class_idxs = select_bbox_idxs[:num], class_idxs[:num]
    if isinstance(max_output_boxes_per_class, torch.Tensor):
        max_output_boxes_per_class = max_output_boxes_per_class.item()

    class_counts = {}  # class index -> how many entries were seen so far
    keep = []
    # One host transfer via tolist() instead of an .item() call per element.
    for class_idx in class_idxs.tolist():
        count = class_counts.get(class_idx, 0)
        keep.append(count < max_output_boxes_per_class)
        class_counts[class_idx] = count + 1

    keep_mask = torch.tensor(keep, dtype=torch.bool)
    return (select_bbox_idxs[keep_mask.to(select_bbox_idxs.device)],
            class_idxs[keep_mask.to(class_idxs.device)])
class ONNXNMSop(torch.autograd.Function):
    """NMS as a custom autograd function exported as ONNX ``NonMaxSuppression``.

    ``forward`` is the eager PyTorch implementation. ``symbolic`` describes
    the node emitted for this op during ONNX export.
    """

    @staticmethod
    def forward(
        ctx,
        boxes: Tensor,
        scores: Tensor,
        max_output_boxes_per_class: Tensor = torch.tensor([100]),
        iou_threshold: Tensor = torch.tensor([0.5]),
        score_threshold: Tensor = torch.tensor([0.05]),
    ) -> Tensor:
        """Non-Maximum Suppression (NMS) implementation.

        Each box keeps only its highest-scoring class. Boxes whose best score
        does not exceed ``score_threshold`` are dropped before the
        (comparatively expensive) class-aware NMS.

        Args:
            boxes (Tensor): Bounding boxes of shape (batch_size, num_boxes, 4),
                in the (x1, y1, x2, y2) layout expected by
                ``torchvision.ops.batched_nms``.
            scores (Tensor): Confidence scores of shape
                (batch_size, num_classes, num_boxes).
            max_output_boxes_per_class (Tensor): Maximum number of output
                boxes per class, applied per image.
            iou_threshold (Tensor): IoU threshold for NMS.
            score_threshold (Tensor): Confidence score threshold.

        Returns:
            Tensor: int32 tensor of shape (num_det, 3). Each row is
            (batch index, class index, box index), where the box index refers
            to the unfiltered input. The shape is (0, 3) when nothing is
            selected.
        """
        device = boxes.device
        batch_size, _, num_boxes = scores.shape
        # batched_nms documents iou_threshold as a Python float.
        iou_thr = float(iou_threshold)
        selected_indices = []
        for batch_idx in range(batch_size):
            boxes_per_image = boxes[batch_idx]
            scores_per_image = scores[batch_idx]
            # If there are no boxes in this image, move to the next one.
            if boxes_per_image.numel() == 0:
                continue
            # Each box keeps only its best class, so reduce over the class axis.
            scores_per_image, class_idxs = torch.max(scores_per_image, dim=0)
            # Apply the score threshold before batched_nms because NMS is the
            # expensive step.
            keep_idxs = scores_per_image > score_threshold
            if not torch.any(keep_idxs):
                continue
            boxes_per_image = boxes_per_image[keep_idxs]
            scores_per_image = scores_per_image[keep_idxs]
            class_idxs = class_idxs[keep_idxs]
            # Map positions in the filtered tensors back to input box indices.
            original_idxs = torch.arange(num_boxes, device=device)[keep_idxs]
            # reference: https://pytorch.org/vision/main/generated/torchvision.ops.batched_nms.html
            select_bbox_idxs = batched_nms(
                boxes_per_image, scores_per_image, class_idxs, iou_thr
            )
            # The per-class cap can only bite when the total number of kept
            # boxes (over all classes) exceeds it.
            if select_bbox_idxs.shape[0] > max_output_boxes_per_class:
                select_bbox_idxs, _ = filter_max_boxes_per_class(
                    select_bbox_idxs,
                    class_idxs[select_bbox_idxs],
                    max_output_boxes_per_class,
                )
            selected_indices.append(
                construct_indice(
                    batch_idx, select_bbox_idxs, class_idxs, original_idxs
                )
            )
        if not selected_indices:
            # Same dtype and rank as the rows built by construct_indice.
            return torch.empty((0, 3), dtype=torch.int32, device=device)
        return torch.cat(selected_indices, dim=0)

    @staticmethod
    def symbolic(
        g,
        boxes: Tensor,
        scores: Tensor,
        max_output_boxes_per_class: Tensor = torch.tensor([100]),
        iou_threshold: Tensor = torch.tensor([0.5]),
        score_threshold: Tensor = torch.tensor([0.05]),
    ):
        """Emit a single ONNX ``NonMaxSuppression`` node with one output."""
        return g.op(
            'NonMaxSuppression',
            boxes,
            scores,
            max_output_boxes_per_class,
            iou_threshold,
            score_threshold,
            outputs=1,
        )
def onnx_nms(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    max_output_boxes_per_class: int = 100,
    iou_threshold: float = 0.5,
    score_threshold: float = 0.05,
    pre_top_k: int = -1,
    keep_top_k: int = 100,
    box_coding: int = 0,
):
    """Run ONNX-exportable NMS and gather per-image detections.

    Args:
        boxes (Tensor): Boxes indexed as ``boxes[batch, box, :]`` with four
            coordinates. With ``box_coding == 1`` they are treated as
            (cx, cy, w, h) and converted to corners before NMS. Otherwise
            they are used as-is.
        scores (Tensor): Scores of shape (batch, num_boxes, num_classes).
            They are transposed to (batch, num_classes, num_boxes) for NMS.
        max_output_boxes_per_class (int): Per-class cap per image.
        iou_threshold (float): IoU threshold for NMS.
        score_threshold (float): Minimum score for a box to enter NMS.
        pre_top_k (int): Accepted for interface compatibility; not used here.
        keep_top_k (int): Maximum detections kept per image (<= 0: no limit).
        box_coding (int): 0 for corner boxes, 1 for center/size boxes.

    Returns:
        tuple: ``(num_dets, batched_dets, batched_scores, batched_labels)``,
        with ``batched_labels`` cast to int32.
    """
    device = boxes.device
    # The NMS parameters are passed as tensors so they become inputs of the
    # exported NonMaxSuppression node. Keep all of them on the boxes' device
    # (the per-class cap previously stayed on the CPU).
    max_output_boxes_per_class = torch.tensor([max_output_boxes_per_class],
                                              device=device)
    iou_threshold = torch.tensor([iou_threshold], device=device)
    score_threshold = torch.tensor([score_threshold], device=device)
    batch_size, _, _ = scores.shape
    if box_coding == 1:
        boxes = boxes @ (_XYWH2XYXY.to(device))
    scores = scores.transpose(1, 2).contiguous()
    selected_indices = ONNXNMSop.apply(boxes, scores,
                                       max_output_boxes_per_class,
                                       iou_threshold, score_threshold)
    num_dets, batched_dets, batched_scores, batched_labels = select_nms_index(
        scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)
    return num_dets, batched_dets, batched_scores, batched_labels.to(
        torch.int32)