# -*- coding: utf-8 -*- # Copyright (c) Facebook, Inc. and its affiliates. from typing import List import torch from torchvision.ops import boxes as box_ops from torchvision.ops import nms # BC-compat def batched_nms( boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float ): """ Same as torchvision.ops.boxes.batched_nms, but safer. """ assert boxes.shape[-1] == 4 # TODO may need better strategy. # Investigate after having a fully-cuda NMS op. if len(boxes) < 40000: # fp16 does not have enough range for batched NMS return box_ops.batched_nms(boxes.float(), scores, idxs, iou_threshold) result_mask = scores.new_zeros(scores.size(), dtype=torch.bool) for id in torch.jit.annotate(List[int], torch.unique(idxs).cpu().tolist()): mask = (idxs == id).nonzero().view(-1) keep = nms(boxes[mask], scores[mask], iou_threshold) result_mask[mask[keep]] = True keep = result_mask.nonzero().view(-1) keep = keep[scores[keep].argsort(descending=True)] return keep # Note: this function (nms_rotated) might be moved into # torchvision/ops/boxes.py in the future def nms_rotated(boxes, scores, iou_threshold): """ Performs non-maximum suppression (NMS) on the rotated boxes according to their intersection-over-union (IoU). Rotated NMS iteratively removes lower scoring rotated boxes which have an IoU greater than iou_threshold with another (higher scoring) rotated box. Note that RotatedBox (5, 3, 4, 2, -90) covers exactly the same region as RotatedBox (5, 3, 4, 2, 90) does, and their IoU will be 1. However, they can be representing completely different objects in certain tasks, e.g., OCR. As for the question of whether rotated-NMS should treat them as faraway boxes even though their IOU is 1, it depends on the application and/or ground truth annotation. As an extreme example, consider a single character v and the square box around it. If the angle is 0 degree, the object (text) would be read as 'v'; If the angle is 90 degrees, the object (text) would become '>'; If the angle is 180 degrees, the object (text) would become '^'; If the angle is 270/-90 degrees, the object (text) would become '<' All of these cases have IoU of 1 to each other, and rotated NMS that only uses IoU as criterion would only keep one of them with the highest score - which, practically, still makes sense in most cases because typically only one of theses orientations is the correct one. Also, it does not matter as much if the box is only used to classify the object (instead of transcribing them with a sequential OCR recognition model) later. On the other hand, when we use IoU to filter proposals that are close to the ground truth during training, we should definitely take the angle into account if we know the ground truth is labeled with the strictly correct orientation (as in, upside-down words are annotated with -180 degrees even though they can be covered with a 0/90/-90 degree box, etc.) The way the original dataset is annotated also matters. For example, if the dataset is a 4-point polygon dataset that does not enforce ordering of vertices/orientation, we can estimate a minimum rotated bounding box to this polygon, but there's no way we can tell the correct angle with 100% confidence (as shown above, there could be 4 different rotated boxes, with angles differed by 90 degrees to each other, covering the exactly same region). In that case we have to just use IoU to determine the box proximity (as many detection benchmarks (even for text) do) unless there're other assumptions we can make (like width is always larger than height, or the object is not rotated by more than 90 degrees CCW/CW, etc.) In summary, not considering angles in rotated NMS seems to be a good option for now, but we should be aware of its implications. Args: boxes (Tensor[N, 5]): Rotated boxes to perform NMS on. They are expected to be in (x_center, y_center, width, height, angle_degrees) format. scores (Tensor[N]): Scores for each one of the rotated boxes iou_threshold (float): Discards all overlapping rotated boxes with IoU < iou_threshold Returns: keep (Tensor): int64 tensor with the indices of the elements that have been kept by Rotated NMS, sorted in decreasing order of scores """ return torch.ops.detectron2.nms_rotated(boxes, scores, iou_threshold) # Note: this function (batched_nms_rotated) might be moved into # torchvision/ops/boxes.py in the future def batched_nms_rotated(boxes, scores, idxs, iou_threshold): """ Performs non-maximum suppression in a batched fashion. Each index value correspond to a category, and NMS will not be applied between elements of different categories. Args: boxes (Tensor[N, 5]): boxes where NMS will be performed. They are expected to be in (x_ctr, y_ctr, width, height, angle_degrees) format scores (Tensor[N]): scores for each one of the boxes idxs (Tensor[N]): indices of the categories for each one of the boxes. iou_threshold (float): discards all overlapping boxes with IoU < iou_threshold Returns: Tensor: int64 tensor with the indices of the elements that have been kept by NMS, sorted in decreasing order of scores """ assert boxes.shape[-1] == 5 if boxes.numel() == 0: return torch.empty((0,), dtype=torch.int64, device=boxes.device) boxes = boxes.float() # fp16 does not have enough range for batched NMS # Strategy: in order to perform NMS independently per class, # we add an offset to all the boxes. The offset is dependent # only on the class idx, and is large enough so that boxes # from different classes do not overlap # Note that batched_nms in torchvision/ops/boxes.py only uses max_coordinate, # which won't handle negative coordinates correctly. # Here by using min_coordinate we can make sure the negative coordinates are # correctly handled. max_coordinate = ( torch.max(boxes[:, 0], boxes[:, 1]) + torch.max(boxes[:, 2], boxes[:, 3]) / 2 ).max() min_coordinate = ( torch.min(boxes[:, 0], boxes[:, 1]) - torch.max(boxes[:, 2], boxes[:, 3]) / 2 ).min() offsets = idxs.to(boxes) * (max_coordinate - min_coordinate + 1) boxes_for_nms = boxes.clone() # avoid modifying the original values in boxes boxes_for_nms[:, :2] += offsets[:, None] keep = nms_rotated(boxes_for_nms, scores, iou_threshold) return keep