# Scraped from a Hugging Face Space (status: Runtime error).
# File size: 9,064 bytes; commit 1c3eb47.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.models.utils import multi_apply
from mmdet.utils import InstanceList, reduce_mean
from mmpl.registry import MODELS, TASK_UTILS
from mmengine.model import BaseModel
from einops import rearrange
from mmpl.utils import ConfigType, OptConfigType
@MODELS.register_module()
class BinarySemanticSegHead(BaseModel):
    """Loss head for single-channel (binary) semantic segmentation.

    The head has no learnable layers of its own: ``forward`` is a no-op and
    the main entry point is :meth:`loss`, which scores mask logits against a
    label map with a configurable mask loss and an optional dice loss.
    :meth:`get_targets` / :meth:`_get_targets_single` are query-matching
    utilities (assigner + sampler) that only work when ``train_cfg`` was
    given, because ``self.assigner`` / ``self.sampler`` are created only then.

    Args:
        num_classes (int): Stored on the instance; not read by this class's
            methods. Defaults to 1.
        align_corners (bool): Stored on the instance; not read by this
            class's methods. Defaults to False.
        loss_mask (ConfigType): Config of the per-pixel mask loss, built with
            ``MODELS.build``. Defaults to a sigmoid ``CrossEntropyLoss`` with
            ``reduction='mean'`` and ``loss_weight=5.0``.
        loss_dice (ConfigType, optional): Config of an optional dice loss.
            When None, :meth:`loss` reports a zero dice term.
        train_cfg (OptConfigType): If truthy, must contain ``assigner`` and
            ``sampler`` configs; may contain ``num_points``,
            ``oversample_ratio`` and ``importance_sample_ratio``.
        test_cfg (OptConfigType): Stored on the instance.
        init_cfg (dict, optional): Forwarded to ``BaseModel``.
    """

    def __init__(
            self,
            num_classes=1,
            align_corners=False,
            loss_mask: ConfigType = dict(
                type='CrossEntropyLoss',
                use_sigmoid=True,
                reduction='mean',
                loss_weight=5.0),
            loss_dice=None,
            train_cfg: OptConfigType = None,
            test_cfg: OptConfigType = None,
            init_cfg: Optional[dict] = None):
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.align_corners = align_corners
        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            # Query-matching components; only get_targets() relies on them.
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            self.sampler = TASK_UTILS.build(
                self.train_cfg['sampler'], default_args=dict(context=self))
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)
        self.loss_mask = MODELS.build(loss_mask)
        # loss_dice is only registered when configured; loss() checks for it.
        if loss_dice is not None:
            self.loss_dice = MODELS.build(loss_dice)

    def forward(self, *args, **kwargs):
        """No-op placeholder: ignores all arguments and returns None."""
        return None

    def loss(self,
             mask_preds: Tensor,
             seg_labels: Tensor,
             ) -> dict:
        """Compute the mask loss and the (optional) dice loss.

        Args:
            mask_preds (Tensor): Mask logits. The first dim is the batch size
                and the last two dims are (h, w).
            seg_labels (Tensor): Binary targets with the same number of
                elements as ``mask_preds`` (NOTE(review): presumably the same
                (bs, [1,] h, w) layout and resolution; confirm with callers).

        Returns:
            dict[str, Tensor]: ``loss_mask`` and ``loss_dice``. ``loss_dice``
            is a zero scalar when no dice loss was configured.
        """
        batch_size = mask_preds.size(0)

        # Dice loss, normalized by the number of images in the batch.
        if hasattr(self, 'loss_dice'):
            loss_dice = self.loss_dice(
                mask_preds, seg_labels, avg_factor=batch_size)
        else:
            # Allocate directly on the prediction's device.
            loss_dice = mask_preds.new_zeros(())

        # Mask loss on flattened pixels: (bs, ..., h, w) -> (num_pixels, 1).
        # The targets are used as-is (no 1 - target inversion).
        num_pixels = mask_preds.numel()
        flat_preds = mask_preds.reshape(-1, 1)
        flat_targets = seg_labels.reshape(-1, 1)
        # Normalize by the pixel count of the WHOLE batch so the term is a
        # per-pixel average that does not scale with batch size (consistent
        # with the batch-averaged dice term). Assumes mmdet-style losses,
        # where avg_factor is the divisor of the summed loss.
        loss_mask = self.loss_mask(
            flat_preds, flat_targets, avg_factor=num_pixels)

        return dict(loss_mask=loss_mask, loss_dice=loss_dice)

    def get_targets(
        self,
        cls_scores_list: List[Tensor],
        mask_preds_list: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        return_sampling_results: bool = False
    ) -> Tuple[List[Union[Tensor, int]]]:
        """Compute classification and mask targets for all images of a
        decoder layer.

        Requires the head to have been built with ``train_cfg``, which
        creates ``self.assigner`` and ``self.sampler``.

        Args:
            cls_scores_list (list[Tensor]): Mask score logits from a single
                decoder layer for all images. Each with shape (num_queries,
                cls_out_channels).
            mask_preds_list (list[Tensor]): Mask logits from a single decoder
                layer for all images. Each with shape (num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.
            return_sampling_results (bool): Whether to also return the list
                of sampling results. Defaults to False.

        Returns:
            tuple: a tuple containing the following targets.

                - labels_list (list[Tensor]): Labels of all images.
                  Each with shape (num_queries, ).
                - label_weights_list (list[Tensor]): Label weights
                  of all images. Each with shape (num_queries, ).
                - mask_targets_list (list[Tensor]): Mask targets of all
                  images. Each holds the ground-truth masks (at their
                  original resolution) of the positive queries only, i.e.
                  shape (num_pos, H, W).
                - mask_weights_list (list[Tensor]): Mask weights of
                  all images. Each with shape (num_queries, ).
                - avg_factor (int): Sum of ``avg_factor`` over the
                  per-image sampling results.
                - sampling_results_list (list): Only present when
                  ``return_sampling_results`` is True, appended as a single
                  tuple element.

            additional_returns: Any values ``_get_targets_single`` returns
                beyond its first seven are appended at the end.
        """
        results = multi_apply(self._get_targets_single, cls_scores_list,
                              mask_preds_list, batch_gt_instances,
                              batch_img_metas)
        (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
         pos_inds_list, neg_inds_list, sampling_results_list) = results[:7]
        rest_results = list(results[7:])

        avg_factor = sum(
            sampling_result.avg_factor
            for sampling_result in sampling_results_list)

        res = (labels_list, label_weights_list, mask_targets_list,
               mask_weights_list, avg_factor)
        if return_sampling_results:
            # The trailing comma matters: append the list as ONE element.
            # Without it, tuple + list raises TypeError.
            res = res + (sampling_results_list, )

        return res + tuple(rest_results)

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict) -> Tuple[Tensor]:
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.
            img_meta (dict): Image information.

        Returns:
            tuple: a tuple containing the following for one image.

                - labels (Tensor): Labels of each query; unmatched queries
                  get 0 (background). Shape (num_queries, ).
                - label_weights (Tensor): All ones. Shape (num_queries, ).
                - mask_targets (Tensor): Ground-truth masks (original
                  resolution) of the positive queries. Shape (num_pos, H, W).
                - mask_weights (Tensor): 1.0 for positive queries, else 0.
                  Shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
                - sampling_result (:obj:`SamplingResult`): Sampling results.
        """
        gt_masks = gt_instances.masks
        gt_labels = gt_instances.labels

        # Matching is done against GT masks resized (nearest) to the
        # prediction resolution.
        target_shape = mask_pred.shape[-2:]
        if gt_masks.shape[0] > 0:
            gt_masks_downsampled = F.interpolate(
                gt_masks.unsqueeze(1).float(), target_shape,
                mode='nearest').squeeze(1).long()
        else:
            # No ground-truth masks: nothing to resize.
            gt_masks_downsampled = gt_masks

        pred_instances = InstanceData(scores=cls_score, masks=mask_pred)
        downsampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_masks_downsampled)

        # Assign and sample (gt indices in assign_result are 1-based).
        assign_result = self.assigner.assign(
            pred_instances=pred_instances,
            gt_instances=downsampled_gt_instances,
            img_meta=img_meta)
        sampling_result = self.sampler.sample(
            assign_result=assign_result,
            pred_instances=pred_instances,
            gt_instances=gt_instances)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # Label target: class 0 is the background, so every query starts as 0
        # and only matched queries receive their GT label.
        num_queries = pred_instances.scores.shape[0]
        labels = gt_labels.new_full((num_queries, ), 0, dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones(num_queries)

        # Mask target: full-resolution GT masks of the matched queries only.
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((num_queries, ))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds,
                neg_inds, sampling_result)
|