Spaces:
Runtime error
Runtime error
File size: 2,198 Bytes
3e06e1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
# Copyright (c) OpenMMLab. All rights reserved.
"""Copied from
https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py."""
import torch
from mmengine.structures import InstanceData
from mmdet.registry import TASK_UTILS
from ..assigners import AssignResult
from .base_sampler import BaseSampler
from .mask_sampling_result import MaskSamplingResult
@TASK_UTILS.register_module()
class MaskPseudoSampler(BaseSampler):
    """A pseudo sampler that does not do sampling actually.

    Every prediction the assigner matched to a ground truth is returned as a
    positive sample, and every prediction it marked as background is returned
    as a negative sample. No subsampling takes place.
    """

    def __init__(self, **kwargs):
        # Any config kwargs are accepted and ignored. BaseSampler.__init__ is
        # intentionally not called: this sampler has no sampling parameters.
        pass

    def _sample_pos(self, **kwargs):
        """Sample positive samples (unsupported; ``sample`` is used directly)."""
        raise NotImplementedError

    def _sample_neg(self, **kwargs):
        """Sample negative samples (unsupported; ``sample`` is used directly)."""
        raise NotImplementedError

    def sample(self, assign_result: AssignResult, pred_instances: InstanceData,
               gt_instances: InstanceData, *args,
               **kwargs) -> MaskSamplingResult:
        """Directly returns the positive and negative indices of samples.

        Args:
            assign_result (:obj:`AssignResult`): Mask assigning results.
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It includes ``scores`` and ``masks`` predicted
                by the model.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``labels`` and ``masks``
                attributes.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            :obj:`MaskSamplingResult`: Sampling result whose positives are
            the predictions with ``gt_inds > 0`` and whose negatives are the
            predictions with ``gt_inds == 0``.
        """
        pred_masks = pred_instances.masks
        gt_masks = gt_instances.masks

        # gt_inds > 0 -> positive, gt_inds == 0 -> negative. Any other value
        # (e.g. negative markers, presumably "ignore") lands in neither set.
        # ``unique()`` also returns the indices sorted in ascending order.
        pos_inds = torch.nonzero(
            assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
        neg_inds = torch.nonzero(
            assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()

        # One flag per prediction, all zero. Presumably this marks samples
        # that are injected ground truths; this sampler never injects any.
        gt_flags = pred_masks.new_zeros(pred_masks.shape[0], dtype=torch.uint8)

        sampling_result = MaskSamplingResult(
            pos_inds=pos_inds,
            neg_inds=neg_inds,
            masks=pred_masks,
            gt_masks=gt_masks,
            assign_result=assign_result,
            gt_flags=gt_flags,
            avg_factor_with_neg=False)
        return sampling_result
|