# Copyright (c) Facebook, Inc. and its affiliates.
from detectron2.layers import batched_nms
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.modeling.roi_heads.roi_heads import Res5ROIHeads
from detectron2.structures import Instances


def merge_branch_instances(instances, num_branch, nms_thresh, topk_per_image):
    """
    Merge detection results from different branches of TridentNet.
    Return detection results by applying non-maximum suppression (NMS) on bounding boxes
    and keep the unsuppressed boxes and other instances (e.g. mask) if any.

    Args:
        instances (list[Instances]): A list of N * num_branch instances that store detection
            results. Contains N images and each image has num_branch instances. The list is
            branch-major: the result for image ``i`` from branch ``j`` is at index
            ``i + N * j``.
        num_branch (int): Number of branches used for merging detection results for each image.
            When 1, ``instances`` is returned unchanged (no NMS is applied).
        nms_thresh (float): The threshold to use for box non-maximum suppression.
            Value in [0, 1].
        topk_per_image (int): The number of top scoring detections to return.
            Set < 0 to return all detections.

    Returns:
        results (list[Instances]): A list of N instances, one for each image in the batch,
            that stores the topk most confident detections after merging results from
            multiple branches.
    """
    if num_branch == 1:
        return instances

    batch_size = len(instances) // num_branch
    results = []
    for i in range(batch_size):
        # Gather this image's predictions from every branch (branch-major layout).
        instance = Instances.cat([instances[i + batch_size * j] for j in range(num_branch)])

        # Apply per-class NMS across the concatenated branch predictions.
        keep = batched_nms(
            instance.pred_boxes.tensor, instance.scores, instance.pred_classes, nms_thresh
        )
        # Taking a prefix as "top-k" assumes batched_nms returns indices in decreasing
        # score order (torchvision's documented contract). A negative topk means "keep
        # all"; slicing with a negative bound would instead drop the last |topk| boxes.
        if topk_per_image >= 0:
            keep = keep[:topk_per_image]
        results.append(instance[keep])

    return results


@ROI_HEADS_REGISTRY.register()
class TridentRes5ROIHeads(Res5ROIHeads):
    """
    The TridentNet ROIHeads in a typical "C4" R-CNN model.
    See :class:`Res5ROIHeads`.
    """

    def __init__(self, cfg, input_shape):
        super().__init__(cfg, input_shape)

        self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH
        # Any TEST_BRANCH_IDX other than -1 enables "trident fast": inference then
        # runs with a single branch (see forward).
        self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1

    def forward(self, images, features, proposals, targets=None):
        """
        See :class:`Res5ROIHeads.forward`.

        During training, returns the parent's predictions and losses unchanged. During
        inference, per-branch predictions are merged per image with
        :func:`merge_branch_instances`, and the loss dict is empty.
        """
        # Use 1 branch if using trident_fast during inference.
        num_branch = self.num_branch if self.training or not self.trident_fast else 1
        # Duplicate targets for all branches in TridentNet (list repetition). Presumably
        # proposals are laid out branch-major as N * num_branch entries; verify upstream.
        all_targets = targets * num_branch if targets is not None else None
        pred_instances, losses = super().forward(images, features, proposals, all_targets)
        # Drop references that are no longer needed.
        del images, all_targets, targets

        if self.training:
            return pred_instances, losses

        pred_instances = merge_branch_instances(
            pred_instances,
            num_branch,
            self.box_predictor.test_nms_thresh,
            self.box_predictor.test_topk_per_image,
        )
        return pred_instances, {}


@ROI_HEADS_REGISTRY.register()
class TridentStandardROIHeads(StandardROIHeads):
    """
    The `StandardROIHeads` for TridentNet.
    See :class:`StandardROIHeads`.
    """

    def __init__(self, cfg, input_shape):
        super().__init__(cfg, input_shape)

        self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH
        # Any TEST_BRANCH_IDX other than -1 enables "trident fast": inference then
        # runs with a single branch (see forward).
        self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1

    def forward(self, images, features, proposals, targets=None):
        """
        See :class:`StandardROIHeads.forward`.

        During training, returns the parent's predictions and losses unchanged. During
        inference, per-branch predictions are merged per image with
        :func:`merge_branch_instances`, and the loss dict is empty.
        """
        # Use 1 branch if using trident_fast during inference.
        num_branch = self.num_branch if self.training or not self.trident_fast else 1
        # Duplicate targets for all branches in TridentNet (list repetition). Presumably
        # proposals are laid out branch-major as N * num_branch entries; verify upstream.
        all_targets = targets * num_branch if targets is not None else None
        pred_instances, losses = super().forward(images, features, proposals, all_targets)
        # Drop references that are no longer needed.
        del images, all_targets, targets

        if self.training:
            return pred_instances, losses

        pred_instances = merge_branch_instances(
            pred_instances,
            num_branch,
            self.box_predictor.test_nms_thresh,
            self.box_predictor.test_topk_per_image,
        )
        return pred_instances, {}