# RSPrompter: mmpl/models/heads/semantic_seg_head.py
# Uploaded by KyanChen ("Upload 159 files", commit 1c3eb47).
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.models.utils import multi_apply
from mmdet.utils import InstanceList, reduce_mean
from mmpl.registry import MODELS, TASK_UTILS
from mmengine.model import BaseModel
from einops import rearrange
from mmpl.utils import ConfigType, OptConfigType
@MODELS.register_module()
class BinarySemanticSegHead(BaseModel):
    """Loss head for binary semantic segmentation.

    The head has no learnable layers: :meth:`forward` is a no-op. It
    computes losses from externally produced mask logits via :meth:`loss`.
    :meth:`get_targets` / :meth:`_get_targets_single` build query-matching
    targets with an assigner and a sampler. Those two objects are only
    built when ``train_cfg`` is given, so the target methods require it.

    Args:
        num_classes (int): Number of classes. Stored but not read by the
            methods of this class. Defaults to 1.
        align_corners (bool): Stored but not read by the methods of this
            class. Defaults to False.
        loss_mask (ConfigType): Config of the per-pixel mask loss.
            Defaults to a sigmoid ``CrossEntropyLoss`` with mean reduction
            and ``loss_weight=5.0``.
        loss_dice (ConfigType, optional): Config of the dice loss. When
            None, no dice module is built and :meth:`loss` reports a zero
            dice loss. Defaults to None.
        train_cfg (ConfigType, optional): Training config. When truthy it
            must contain ``assigner`` and ``sampler`` configs. It may also
            set ``num_points``, ``oversample_ratio`` and
            ``importance_sample_ratio``. Defaults to None.
        test_cfg (ConfigType, optional): Testing config, stored as-is.
            Defaults to None.
        init_cfg (dict, optional): Initialization config forwarded to
            ``BaseModel``. Defaults to None.
    """

    def __init__(
            self,
            num_classes=1,
            align_corners=False,
            loss_mask: ConfigType = dict(
                type='CrossEntropyLoss',
                use_sigmoid=True,
                reduction='mean',
                loss_weight=5.0),
            loss_dice=None,
            train_cfg: OptConfigType = None,
            test_cfg: OptConfigType = None,
            init_cfg: Optional[dict] = None):
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.align_corners = align_corners
        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            # Target assignment / sampling components are training-only.
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            self.sampler = TASK_UTILS.build(
                self.train_cfg['sampler'], default_args=dict(context=self))
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)
        self.loss_mask = MODELS.build(loss_mask)
        if loss_dice is not None:
            # ``loss`` detects the optional dice loss with ``hasattr``, so
            # the attribute is only created when a dice loss is configured.
            self.loss_dice = MODELS.build(loss_dice)

    def forward(self, *args, **kwargs):
        """No-op: this head has no forward computation and returns None."""
        return None

    def loss(self,
             mask_preds: Tensor,
             seg_labels: Tensor,
             ):
        """Compute the mask loss and, if configured, the dice loss.

        Args:
            mask_preds (Tensor): Mask logits. The first dimension is the
                batch and the last two dimensions are (h, w). Presumably
                (bs, 1, h, w) or (bs, h, w); TODO confirm with the caller.
            seg_labels (Tensor): Binary targets with the same number of
                elements as ``mask_preds``.

        Returns:
            dict[str, Tensor]: ``loss_mask`` and ``loss_dice``.
            ``loss_dice`` is a zero scalar when no dice loss is configured.
        """
        bs = mask_preds.size(0)

        # Dice loss (optional), averaged over the batch.
        if hasattr(self, 'loss_dice'):
            loss_dice = self.loss_dice(mask_preds, seg_labels, avg_factor=bs)
        else:
            loss_dice = torch.zeros([], device=mask_preds.device)

        # Mask loss: treat every pixel as one sample with a single logit,
        # i.e. flatten (..., h, w) -> (num_pixels, 1) for preds and targets.
        h, w = mask_preds.shape[-2:]
        flat_preds = mask_preds.reshape(-1, 1)
        flat_targets = seg_labels.reshape(-1, 1)
        # NOTE(review): avg_factor is h * w rather than the total number of
        # pixels. With a loss that divides its summed value by avg_factor,
        # this is the per-pixel mean times the number of (h, w) maps in the
        # batch, whereas the dice loss is averaged over ``bs``. Kept as-is
        # to preserve the existing loss scale; confirm it is intended.
        loss_mask = self.loss_mask(flat_preds, flat_targets, avg_factor=h * w)

        return dict(loss_mask=loss_mask, loss_dice=loss_dice)

    def get_targets(
        self,
        cls_scores_list: List[Tensor],
        mask_preds_list: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        return_sampling_results: bool = False
    ) -> Tuple[List[Union[Tensor, int]]]:
        """Compute classification and mask targets for all images for a
        decoder layer.

        Args:
            cls_scores_list (list[Tensor]): Mask score logits from a single
                decoder layer for all images. Each with shape (num_queries,
                cls_out_channels).
            mask_preds_list (list[Tensor]): Mask logits from a single decoder
                layer for all images. Each with shape (num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.
            return_sampling_results (bool): Whether to return the sampling
                results. Defaults to False.

        Returns:
            tuple: a tuple containing the following targets.

                - labels_list (list[Tensor]): Labels of all images.\
                    Each with shape (num_queries, ).
                - label_weights_list (list[Tensor]): Label weights\
                    of all images. Each with shape (num_queries, ).
                - mask_targets_list (list[Tensor]): Mask targets of\
                    all images. Each with shape (num_queries, h, w).
                - mask_weights_list (list[Tensor]): Mask weights of\
                    all images. Each with shape (num_queries, ).
                - avg_factor (int): Sum of the ``avg_factor`` of every\
                    image's sampling result. When using a sampling method\
                    this is usually the number of positive plus negative\
                    priors; with `MaskPseudoSampler` it is usually the\
                    number of positive priors.
                - sampling_results_list (list): Per-image sampling results.\
                    Only present when ``return_sampling_results`` is True.

            additional_returns: Any extra values returned by
                `self._get_targets_single` beyond the first seven are
                appended to the tuple (one per-image list each).
        """
        results = multi_apply(self._get_targets_single, cls_scores_list,
                              mask_preds_list, batch_gt_instances,
                              batch_img_metas)
        (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
         pos_inds_list, neg_inds_list, sampling_results_list) = results[:7]
        rest_results = list(results[7:])

        avg_factor = sum(
            sampling_result.avg_factor
            for sampling_result in sampling_results_list)

        res = (labels_list, label_weights_list, mask_targets_list,
               mask_weights_list, avg_factor)
        if return_sampling_results:
            # Must be a 1-tuple: ``res + (sampling_results_list)`` would try
            # to concatenate a tuple with a list and raise TypeError.
            res = res + (sampling_results_list, )

        return res + tuple(rest_results)

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict) -> Tuple[Tensor]:
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.
            img_meta (dict): Image information.

        Returns:
            tuple: a tuple containing the following for one image.

                - labels (Tensor): Labels of each query, 0 (background) for
                    unmatched queries. shape (num_queries, ).
                - label_weights (Tensor): Label weights of each image.
                    shape (num_queries, ).
                - mask_targets (Tensor): Ground-truth masks (original
                    resolution) of the positively assigned queries.
                - mask_weights (Tensor): Mask weights of each image,
                    1 for positive queries and 0 otherwise.
                    shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
                - sampling_result (:obj:`SamplingResult`): Sampling results.
        """
        gt_masks = gt_instances.masks
        gt_labels = gt_instances.labels

        # Resize GT masks to the prediction resolution for assignment only.
        target_shape = mask_pred.shape[-2:]
        if gt_masks.shape[0] > 0:
            gt_masks_downsampled = F.interpolate(
                gt_masks.unsqueeze(1).float(), target_shape,
                mode='nearest').squeeze(1).long()
        else:
            gt_masks_downsampled = gt_masks

        pred_instances = InstanceData(scores=cls_score, masks=mask_pred)
        downsampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_masks_downsampled)
        # Assign and sample. NOTE(review): the original comment stated that
        # assign_result is 1-based; confirm against the assigner in use.
        assign_result = self.assigner.assign(
            pred_instances=pred_instances,
            gt_instances=downsampled_gt_instances,
            img_meta=img_meta)
        sampling_result = self.sampler.sample(
            assign_result=assign_result,
            pred_instances=pred_instances,
            gt_instances=gt_instances)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # Label targets: class 0 is background, so every query starts at 0
        # and only positively assigned queries receive their GT label.
        num_queries = pred_instances.scores.shape[0]
        labels = gt_labels.new_full((num_queries, ),
                                    0,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones(num_queries)

        # Mask targets come from the original-resolution GT masks.
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((num_queries, ))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds,
                neg_inds, sampling_result)