# wondervictor's picture
# update lfs
# f5fdf51
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor
from torchvision.ops import batched_nms
_XYWH2XYXY = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0],
[-0.5, 0.0, 0.5, 0.0], [0.0, -0.5, 0.0, 0.5]],
dtype=torch.float32)
def sort_nms_index(nms_index, scores, batch_size, keep_top_k=-1):
    """Group NMS results by image, order each image by score, and cap it.

    Rows of ``nms_index`` are ``(batch_idx, class_idx, box_idx)`` triples.
    They are first grouped by batch index; inside each image the rows are
    ordered by descending score, and when ``keep_top_k > 0`` only the first
    ``keep_top_k`` rows of every image are kept.

    Args:
        nms_index (Tensor): Selected indices of shape (num_det, 3).
        scores (Tensor): Score tensor indexed as
            ``scores[batch_idx, class_idx, box_idx]``.
        batch_size (int): Number of images; ``num_dets`` has at least this
            many entries (images without detections count as 0).
        keep_top_k (int): Per-image cap on kept rows; ``<= 0`` disables it.

    Returns:
        tuple[Tensor, Tensor]: The reordered (possibly truncated)
        ``nms_index`` with its original dtype, and ``num_dets``, the number
        of rows kept for each image.
    """
    # Group rows by image.  Index with an int64 view; the returned rows keep
    # the caller's dtype.
    batch_order = torch.argsort(nms_index[:, 0].long(), dim=0)
    nms_index = nms_index[batch_order]
    idx = nms_index.long()
    det_scores = scores[idx[:, 0], idx[:, 1], idx[:, 2]]
    num_dets = torch.bincount(idx[:, 0], minlength=batch_size)

    # Build every image's block independently.  ``det_scores`` and
    # ``nms_index`` are sliced with the same offsets and never modified in
    # the loop, so they stay aligned even when earlier images are truncated.
    blocks = []
    start = 0
    for count in num_dets.tolist():
        end = start + count
        order = torch.argsort(det_scores[start:end], descending=True)
        block = nms_index[start:end][order]
        if keep_top_k > 0:
            block = block[:keep_top_k]
        blocks.append(block)
        start = end

    if keep_top_k > 0:
        num_dets = num_dets.clamp(max=keep_top_k)
    if blocks:
        nms_index = torch.cat(blocks, dim=0)
    return nms_index, num_dets
def select_nms_index(
    scores: Tensor,
    boxes: Tensor,
    nms_index: Tensor,
    batch_size: int,
    keep_top_k: int = -1,
):
    """Gather the final detections addressed by ``nms_index``.

    Args:
        scores: Score tensor indexed as ``scores[batch, class, box]``.
        boxes: Box tensor indexed as ``boxes[batch, box]``.
        nms_index: Rows of ``(batch_idx, class_idx, box_idx)``, shape
            (num_det, 3).
        batch_size: Number of images in the batch.
        keep_top_k: Per-image cap forwarded to ``sort_nms_index``; ``<= 0``
            keeps every detection.

    Returns:
        tuple: ``(num_dets, batched_dets, batched_scores, batched_labels)``.
        These are the per-image detection counts (length ``batch_size``), the
        selected boxes, their scores, and their class indices. Detections are
        grouped by image and ordered by descending score within each image.
    """
    if nms_index.numel() == 0:
        # Nothing survived NMS: still report a count (zero) for every image,
        # like the non-empty path, and keep the inputs' device.
        device = scores.device
        return (
            torch.zeros(batch_size, dtype=torch.long, device=device),
            boxes.new_zeros((0, 4)),
            scores.new_zeros((0,)),
            torch.zeros(0, dtype=torch.int32, device=device),
        )
    nms_index, num_dets = sort_nms_index(nms_index, scores, batch_size, keep_top_k)
    # Advanced indexing is done with int64 copies of the index columns.
    idx = nms_index.long()
    batch_inds, cls_inds, box_inds = idx[:, 0], idx[:, 1], idx[:, 2]
    batched_scores = scores[batch_inds, cls_inds, box_inds]
    batched_dets = boxes[batch_inds, box_inds, ...]
    # Labels keep the dtype of the incoming index tensor.
    batched_labels = nms_index[:, 1]
    return num_dets, batched_dets, batched_scores, batched_labels
def construct_indice(batch_idx, select_bbox_idxs, class_idxs, original_idxs):
    """Build int32 ``(batch_idx, class_idx, box_idx)`` rows for one image.

    ``select_bbox_idxs`` indexes into the (filtered) per-image tensors; the
    class column is looked up in ``class_idxs`` and the box column is mapped
    back through ``original_idxs`` to positions in the unfiltered input.
    """
    rows = torch.zeros(
        (len(select_bbox_idxs), 3),
        dtype=torch.int32,
        device=select_bbox_idxs.device,
    )
    rows[:, 0] = batch_idx
    rows[:, 1] = class_idxs[select_bbox_idxs]
    rows[:, 2] = original_idxs[select_bbox_idxs]
    return rows
def filter_max_boxes_per_class(
    select_bbox_idxs, class_idxs, max_output_boxes_per_class
):
    """Keep at most ``max_output_boxes_per_class`` entries for each class.

    Entries are scanned in the given order, and for every class the first
    ``max_output_boxes_per_class`` occurrences are kept. Later occurrences
    are dropped, so the input order decides which boxes survive.

    Args:
        select_bbox_idxs (Tensor): 1-D box indices.
        class_idxs (Tensor): 1-D class index for each entry of
            ``select_bbox_idxs`` (same length).
        max_output_boxes_per_class (int | Tensor): Per-class cap. A
            one-element tensor is accepted.

    Returns:
        tuple[Tensor, Tensor]: The kept box indices and their class indices.
        Both keep the inputs' dtype and device, including when nothing is
        kept.
    """
    limit = int(max_output_boxes_per_class)
    class_counts = {}  # class index -> number of entries kept so far
    keep = []
    for class_idx in class_idxs.tolist():
        count = class_counts.get(class_idx, 0)
        kept = count < limit
        keep.append(kept)
        if kept:
            class_counts[class_idx] = count + 1
    # Select with a mask instead of rebuilding tensors from Python lists, so
    # the results stay on the input device and keep an integer dtype.
    keep_mask = torch.tensor(keep, dtype=torch.bool, device=select_bbox_idxs.device)
    return select_bbox_idxs[keep_mask], class_idxs[keep_mask]
class ONNXNMSop(torch.autograd.Function):
    """NMS op that runs eagerly in PyTorch and exports as ONNX NonMaxSuppression.

    ``forward`` computes the selection with torchvision's ``batched_nms``;
    ``symbolic`` emits a single ``NonMaxSuppression`` node during ONNX export.
    """

    @staticmethod
    def forward(
        ctx,
        boxes: Tensor,
        scores: Tensor,
        max_output_boxes_per_class: Tensor = torch.tensor([100]),
        iou_threshold: Tensor = torch.tensor([0.5]),
        score_threshold: Tensor = torch.tensor([0.05])
    ) -> Tensor:
        """
        Non-Maximum Suppression (NMS) implementation.

        Args:
            boxes (Tensor): Bounding boxes of shape (batch_size, num_boxes, 4).
            scores (Tensor): Confidence scores of shape
                (batch_size, num_classes, num_boxes).
            max_output_boxes_per_class (Tensor): Maximum number of output
                boxes per class.
            iou_threshold (Tensor): IoU threshold for NMS.
            score_threshold (Tensor): Boxes whose best score is not strictly
                greater than this are discarded before NMS.

        Returns:
            Tensor: int32 selected indices of shape (num_det, 3). Each row is
            (batch index, class index, box index). When nothing is selected,
            an empty (0, 3) int32 tensor is returned.
        """
        device = boxes.device
        batch_size, num_classes, num_boxes = scores.shape
        selected_indices = []
        for batch_idx in range(batch_size):
            boxes_per_image = boxes[batch_idx]
            scores_per_image = scores[batch_idx]
            # If no boxes in this image, continue to the next image
            if boxes_per_image.numel() == 0:
                continue
            # Each box keeps only its highest-scoring class.
            # NOTE(review): the ONNX operator scores every (class, box) pair
            # independently, so exported and eager results can differ when a
            # box scores well for several classes.
            scores_per_image, class_idxs = torch.max(scores_per_image, dim=0)
            # Apply the score threshold before batched_nms, which is the
            # expensive step.
            keep_idxs = scores_per_image > score_threshold
            if not torch.any(keep_idxs):
                continue
            boxes_per_image = boxes_per_image[keep_idxs]
            scores_per_image = scores_per_image[keep_idxs]
            class_idxs = class_idxs[keep_idxs]
            # Map filtered positions back to indices into the unfiltered input.
            original_idxs = torch.arange(num_boxes, device=device)[keep_idxs]
            # reference: https://pytorch.org/vision/main/generated/torchvision.ops.batched_nms.html
            select_bbox_idxs = batched_nms(
                boxes_per_image, scores_per_image, class_idxs, iou_threshold
            )
            # The per-class cap can only bind when more boxes than the cap
            # survived in total.
            if select_bbox_idxs.shape[0] > max_output_boxes_per_class:
                select_bbox_idxs, _ = filter_max_boxes_per_class(
                    select_bbox_idxs,
                    class_idxs[select_bbox_idxs],
                    max_output_boxes_per_class,
                )
                # Every image's rows must share a device and an integer index
                # dtype so the final torch.cat succeeds.
                select_bbox_idxs = select_bbox_idxs.to(device=device, dtype=torch.long)
            selected_indices.append(
                construct_indice(batch_idx, select_bbox_idxs, class_idxs, original_idxs)
            )
        if not selected_indices:
            # Same shape contract and dtype as the non-empty result.
            return torch.zeros((0, 3), dtype=torch.int32, device=device)
        return torch.cat(selected_indices, dim=0)

    @staticmethod
    def symbolic(
        g,
        boxes: Tensor,
        scores: Tensor,
        max_output_boxes_per_class: Tensor = torch.tensor([100]),
        iou_threshold: Tensor = torch.tensor([0.5]),
        score_threshold: Tensor = torch.tensor([0.05]),
    ):
        """Emit one ONNX ``NonMaxSuppression`` node with the same inputs."""
        return g.op(
            'NonMaxSuppression',
            boxes,
            scores,
            max_output_boxes_per_class,
            iou_threshold,
            score_threshold,
            outputs=1)
def onnx_nms(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    max_output_boxes_per_class: int = 100,
    iou_threshold: float = 0.5,
    score_threshold: float = 0.05,
    pre_top_k: int = -1,
    keep_top_k: int = 100,
    box_coding: int = 0,
):
    """Run batched NMS through the ONNX-exportable ``ONNXNMSop``.

    Args:
        boxes: Boxes indexed as ``boxes[batch, box]`` with 4 coordinates in
            the last dimension.
        scores: Scores of shape (batch, num_boxes, num_classes). They are
            transposed to (batch, num_classes, num_boxes) before NMS.
        max_output_boxes_per_class: Per-class cap on boxes kept by NMS.
        iou_threshold: IoU threshold for suppression.
        score_threshold: Minimum score (exclusive) for a box to enter NMS.
        pre_top_k: Accepted for API compatibility; currently unused.
        keep_top_k: Per-image cap on returned detections; ``<= 0`` disables it.
        box_coding: ``1`` means boxes are (cx, cy, w, h) and are converted to
            (x1, y1, x2, y2) first. Any other value passes boxes through
            unchanged.

    Returns:
        tuple: ``(num_dets, batched_dets, batched_scores, batched_labels)``,
        with labels cast to int32.
    """
    device = boxes.device
    max_output_boxes_per_class = torch.tensor([max_output_boxes_per_class])
    iou_threshold = torch.tensor([iou_threshold]).to(device)
    score_threshold = torch.tensor([score_threshold]).to(device)
    batch_size, _, _ = scores.shape
    if box_coding == 1:
        # Match both device and dtype so fp16/bf16 boxes don't fail the matmul
        # against the float32 conversion matrix.
        boxes = boxes @ _XYWH2XYXY.to(boxes)
    scores = scores.transpose(1, 2).contiguous()
    selected_indices = ONNXNMSop.apply(boxes, scores,
                                       max_output_boxes_per_class,
                                       iou_threshold, score_threshold)
    num_dets, batched_dets, batched_scores, batched_labels = select_nms_index(
        scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)
    return num_dets, batched_dets, batched_scores, batched_labels.to(
        torch.int32)