|
|
|
from typing import Union |
|
|
|
import torch |
|
from mmengine.structures import InstanceData |
|
from numpy import ndarray |
|
from torch import Tensor |
|
|
|
from mmdet.registry import TASK_UTILS |
|
from ..assigners import AssignResult |
|
from .multi_instance_sampling_result import MultiInstanceSamplingResult |
|
from .random_sampler import RandomSampler |
|
|
|
|
|
@TASK_UTILS.register_module()
class MultiInsRandomSampler(RandomSampler):
    """Random sampler for the multi-instance setting.

    Note:
        Multi-instance means predicting multiple detection boxes from
        one proposal box. `AssignResult` may assign several gt boxes to
        each proposal box; in that case `RandomSampler` should be
        replaced by `MultiInsRandomSampler`.
    """

    def _sample_pos(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> Union[Tensor, ndarray]:
        """Randomly pick positive samples.

        A sample is treated as positive when the first column of
        ``assign_result.labels`` is greater than 0.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            num_expected (int): The number of expected positive samples.

        Returns:
            Tensor or ndarray: Sampled indices. All positive indices are
            returned when there are no more than ``num_expected`` of
            them; otherwise the result of
            ``self.random_choice(indices, num_expected)`` is returned.
        """
        is_positive = assign_result.labels[:, 0] > 0
        candidates = torch.nonzero(is_positive, as_tuple=False)
        # Squeeze axis 1 of the nonzero output; an empty result is
        # deliberately left unsqueezed.
        if candidates.numel() != 0:
            candidates = candidates.squeeze(1)
        if candidates.numel() > num_expected:
            return self.random_choice(candidates, num_expected)
        return candidates

    def _sample_neg(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> Union[Tensor, ndarray]:
        """Randomly pick negative samples.

        A sample is treated as negative when the first column of
        ``assign_result.labels`` equals 0.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            num_expected (int): The number of expected negative samples.

        Returns:
            Tensor or ndarray: Sampled indices. All negative indices are
            returned when there are no more than ``num_expected`` of
            them; otherwise the result of
            ``self.random_choice(indices, num_expected)`` is returned.
        """
        is_negative = assign_result.labels[:, 0] == 0
        candidates = torch.nonzero(is_negative, as_tuple=False)
        # Squeeze axis 1 of the nonzero output; an empty result is
        # deliberately left unsqueezed.
        if candidates.numel() != 0:
            candidates = candidates.squeeze(1)
        if len(candidates) > num_expected:
            return self.random_choice(candidates, num_expected)
        return candidates

    def sample(self, assign_result: AssignResult, pred_instances: InstanceData,
               gt_instances: InstanceData,
               **kwargs) -> MultiInstanceSamplingResult:
        """Sample positive and negative bboxes.

        The gt and ignore boxes are appended after the priors, and
        ``gt_flags`` marks those appended rows with 1 and the prior rows
        with 0.

        Args:
            assign_result (:obj:`AssignResult`): Assigning results from
                MultiInstanceAssigner.
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It includes ``priors``, and the priors can
                be anchors or points, or the bboxes predicted by the
                previous stage, has shape (n, 4). The bboxes predicted by
                the current model or stage will be named ``bboxes``,
                ``labels``, and ``scores``, the same as the
                ``InstanceData`` in other places.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes``, with shape
                (k, 4), and ``labels``, with shape (k, ).
            **kwargs: Must contain ``batch_gt_instances_ignore``, an
                object whose ``bboxes`` attribute holds the ignore boxes.

        Returns:
            :obj:`MultiInstanceSamplingResult`: Sampling result.
        """
        assert 'batch_gt_instances_ignore' in kwargs, \
            'batch_gt_instances_ignore is necessary for MultiInsRandomSampler'

        gt_and_ignore_bboxes = torch.cat(
            [gt_instances.bboxes, kwargs['batch_gt_instances_ignore'].bboxes],
            dim=0)

        # Promote a single 1-D prior to one row, then keep only the first
        # four columns.
        proposals = pred_instances.priors
        if proposals.dim() < 2:
            proposals = proposals[None, :]
        proposals = proposals[:, :4]

        # Append gt/ignore boxes after the proposals; flag rows are 0 for
        # proposals and 1 for the appended boxes.
        num_proposals = proposals.shape[0]
        num_boxes = gt_and_ignore_bboxes.shape[0]
        priors = torch.cat([proposals, gt_and_ignore_bboxes], dim=0)
        gt_flags = torch.cat([
            priors.new_zeros((num_proposals, ), dtype=torch.uint8),
            priors.new_ones(num_boxes, dtype=torch.uint8),
        ])

        # Positives first; duplicates are dropped before counting.
        num_expected_pos = int(self.num * self.pos_fraction)
        pos_inds = self.pos_sampler._sample_pos(assign_result,
                                                num_expected_pos).unique()
        num_sampled_pos = pos_inds.numel()

        # Negatives fill the remaining budget, optionally capped at
        # ``neg_pos_ub`` times the positive count (counted as at least 1).
        num_expected_neg = self.num - num_sampled_pos
        if self.neg_pos_ub >= 0:
            neg_upper_bound = int(self.neg_pos_ub * max(1, num_sampled_pos))
            num_expected_neg = min(num_expected_neg, neg_upper_bound)
        neg_inds = self.neg_sampler._sample_neg(assign_result,
                                                num_expected_neg).unique()

        return MultiInstanceSamplingResult(
            pos_inds=pos_inds,
            neg_inds=neg_inds,
            priors=priors,
            gt_and_ignore_bboxes=gt_and_ignore_bboxes,
            assign_result=assign_result,
            gt_flags=gt_flags)
|
|