|
|
|
"""copy from |
|
https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py.""" |
|
|
|
import torch |
|
from torch import Tensor |
|
|
|
from ..assigners import AssignResult |
|
from .sampling_result import SamplingResult |
|
|
|
|
|
class MaskSamplingResult(SamplingResult):
    """Sampling result holding the masks picked as positives and negatives.

    Args:
        pos_inds (Tensor): Index tensor selecting the positive samples from
            ``masks``, ``gt_flags`` and ``assign_result.gt_inds``.
        neg_inds (Tensor): Index tensor selecting the negative samples from
            ``masks``.
        masks (Tensor): Masks to sample from, indexed along the first
            dimension.
        gt_masks (Tensor): Ground-truth masks; the size of the first
            dimension is stored as ``num_gts``.
        assign_result (AssignResult): Assignment whose ``gt_inds`` attribute
            is indexed with ``pos_inds``.
        gt_flags (Tensor): Flags indexed with ``pos_inds`` to produce
            ``pos_is_gt``.
        avg_factor_with_neg (bool): If True, ``avg_factor`` is
            ``num_pos + num_neg``; otherwise it is ``num_pos`` only.
            Defaults to True.
    """

    def __init__(self,
                 pos_inds: Tensor,
                 neg_inds: Tensor,
                 masks: Tensor,
                 gt_masks: Tensor,
                 assign_result: AssignResult,
                 gt_flags: Tensor,
                 avg_factor_with_neg: bool = True) -> None:
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds

        # Both counts are clamped so that they are never smaller than 1.
        self.num_pos = max(pos_inds.numel(), 1)
        self.num_neg = max(neg_inds.numel(), 1)
        if avg_factor_with_neg:
            self.avg_factor = self.num_pos + self.num_neg
        else:
            self.avg_factor = self.num_pos

        # Gather the sampled entries along the first dimension.
        self.pos_masks = masks[pos_inds]
        self.neg_masks = masks[neg_inds]
        self.pos_is_gt = gt_flags[pos_inds]

        self.num_gts = gt_masks.shape[0]
        # NOTE(review): the ``- 1`` presumably converts 1-based ``gt_inds``
        # to 0-based indices into ``gt_masks`` -- confirm against
        # AssignResult.
        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1

        if gt_masks.numel() != 0:
            self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :]
        else:
            # Nothing can be indexed in an empty ``gt_masks``; there must be
            # no assigned positives, and an empty tensor is stored instead.
            assert self.pos_assigned_gt_inds.numel() == 0
            self.pos_gt_masks = torch.empty_like(gt_masks)

    @property
    def masks(self) -> Tensor:
        """torch.Tensor: positive masks followed by negative masks."""
        return torch.cat((self.pos_masks, self.neg_masks))

    def __nice__(self) -> str:
        """Return a multi-line summary of ``info``, one sorted key per line.

        The two mask entries are shown by shape rather than by value.
        """
        summary = dict(self.info)
        for mask_key in ('pos_masks', 'neg_masks'):
            summary[mask_key] = summary[mask_key].shape
        entries = [
            f"'{name}': {value!r}" for name, value in sorted(summary.items())
        ]
        return '{\n ' + ',\n '.join(entries) + '\n}'

    @property
    def info(self) -> dict:
        """dict: the main attributes of this sampling result, keyed by name."""
        return dict(
            pos_inds=self.pos_inds,
            neg_inds=self.neg_inds,
            pos_masks=self.pos_masks,
            neg_masks=self.neg_masks,
            pos_is_gt=self.pos_is_gt,
            num_gts=self.num_gts,
            pos_assigned_gt_inds=self.pos_assigned_gt_inds,
        )
|
|