File size: 12,185 Bytes
938e515 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 |
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import List
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.layers import Conv2d, ConvTranspose2d, ShapeSpec, cat, get_norm
from detectron2.layers.wrappers import move_device_like
from detectron2.structures import Instances
from detectron2.utils.events import get_event_storage
from detectron2.utils.registry import Registry
# Public names of this module, i.e. what a star-import exposes.
__all__ = [
"BaseMaskRCNNHead",
"MaskRCNNConvUpsampleHead",
"build_mask_head",
"ROI_MASK_HEAD_REGISTRY",
]
# Name -> mask head lookup table. Heads add themselves with the
# `@ROI_MASK_HEAD_REGISTRY.register()` decorator, and `build_mask_head` fetches
# the head named by `cfg.MODEL.ROI_MASK_HEAD.NAME` from it.
ROI_MASK_HEAD_REGISTRY = Registry("ROI_MASK_HEAD")
# The text below is attached to the registry object at runtime as its docstring.
ROI_MASK_HEAD_REGISTRY.__doc__ = """
Registry for mask heads, which predicts instance masks given
per-region features.
The registered object will be called with `obj(cfg, input_shape)`.
"""
@torch.jit.unused
def mask_rcnn_loss(pred_mask_logits: torch.Tensor, instances: List[Instances], vis_period: int = 0):
    """
    Compute the mask prediction loss defined in the Mask R-CNN paper.

    Besides returning the loss, this function logs the scalars
    "mask_rcnn/accuracy", "mask_rcnn/false_positive" and "mask_rcnn/false_negative"
    to the current event storage, and optionally dumps visualization images.

    Args:
        pred_mask_logits (Tensor): A tensor of shape (B, C, Hmask, Wmask) or (B, 1, Hmask, Wmask)
            for class-specific or class-agnostic, where B is the total number of predicted masks
            in all images, C is the number of foreground classes, and Hmask, Wmask are the height
            and width of the mask predictions. The values are logits. Hmask must equal Wmask.
        instances (list[Instances]): A list of N Instances, where N is the number of images
            in the batch. These instances are in 1:1 correspondence with the pred_mask_logits.
            The fields read here are "gt_masks" and "proposal_boxes", plus "gt_classes" when
            the prediction is class-specific. Images with zero instances are skipped.
        vis_period (int): the period (in steps) to dump visualization. 0 disables it.

    Returns:
        mask_loss (Tensor): A scalar tensor containing the loss. When no image contributes
            any instance, the result is ``pred_mask_logits.sum() * 0``, i.e. a zero that is
            still connected to the autograd graph of the logits.
    """
    cls_agnostic_mask = pred_mask_logits.size(1) == 1
    total_num_masks = pred_mask_logits.size(0)
    mask_side_len = pred_mask_logits.size(2)
    assert pred_mask_logits.size(2) == pred_mask_logits.size(3), "Mask prediction must be square!"

    gt_classes = []
    gt_masks = []
    for instances_per_image in instances:
        if len(instances_per_image) == 0:
            continue
        if not cls_agnostic_mask:
            gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64)
            gt_classes.append(gt_classes_per_image)

        # Rasterize each ground-truth mask inside its proposal box at the prediction resolution.
        gt_masks_per_image = instances_per_image.gt_masks.crop_and_resize(
            instances_per_image.proposal_boxes.tensor, mask_side_len
        ).to(device=pred_mask_logits.device)
        # A tensor of shape (N, M, M), N=#instances in the image; M=mask_side_len
        gt_masks.append(gt_masks_per_image)

    if not gt_masks:
        # Nothing to supervise: return a zero that keeps the graph of the logits alive.
        return pred_mask_logits.sum() * 0

    gt_masks = cat(gt_masks, dim=0)

    if cls_agnostic_mask:
        pred_mask_logits = pred_mask_logits[:, 0]
    else:
        # Build the row index on the same device as the logits so that the advanced
        # indexing below does not need an implicit cross-device copy of the index.
        indices = torch.arange(total_num_masks, device=pred_mask_logits.device)
        gt_classes = cat(gt_classes, dim=0)
        # Keep, for every mask, only the channel of its ground-truth class.
        pred_mask_logits = pred_mask_logits[indices, gt_classes]

    if gt_masks.dtype == torch.bool:
        gt_masks_bool = gt_masks
    else:
        # Here we allow gt_masks to be float as well (depend on the implementation of rasterize())
        gt_masks_bool = gt_masks > 0.5
    gt_masks = gt_masks.to(dtype=torch.float32)

    # Log the training accuracy (using gt classes and sigmoid(0.0) == 0.5 threshold)
    mask_incorrect = (pred_mask_logits > 0.0) != gt_masks_bool
    # The max(..., 1.0) guards below avoid division by zero on degenerate inputs.
    mask_accuracy = 1 - (mask_incorrect.sum().item() / max(mask_incorrect.numel(), 1.0))
    num_positive = gt_masks_bool.sum().item()
    false_positive = (mask_incorrect & ~gt_masks_bool).sum().item() / max(
        gt_masks_bool.numel() - num_positive, 1.0
    )
    false_negative = (mask_incorrect & gt_masks_bool).sum().item() / max(num_positive, 1.0)

    storage = get_event_storage()
    storage.put_scalar("mask_rcnn/accuracy", mask_accuracy)
    storage.put_scalar("mask_rcnn/false_positive", false_positive)
    storage.put_scalar("mask_rcnn/false_negative", false_negative)
    if vis_period > 0 and storage.iter % vis_period == 0:
        pred_masks = pred_mask_logits.sigmoid()
        # Place prediction and ground truth side by side along the width axis.
        vis_masks = torch.cat([pred_masks, gt_masks], dim=2)
        name = "Left: mask prediction; Right: mask GT"
        for idx, vis_mask in enumerate(vis_masks):
            # Repeat the single-channel image three times to form a 3-channel image.
            vis_mask = torch.stack([vis_mask] * 3, dim=0)
            storage.put_image(name + f" ({idx})", vis_mask)

    mask_loss = F.binary_cross_entropy_with_logits(pred_mask_logits, gt_masks, reduction="mean")
    return mask_loss
def mask_rcnn_inference(pred_mask_logits: torch.Tensor, pred_instances: List[Instances]):
    """
    Turn mask logits into soft foreground-probability masks and attach them to the
    predicted instances. For class-specific logits only the channel of each box's
    predicted class is kept. The result is stored in a new "pred_masks" field on
    every element of pred_instances.

    Args:
        pred_mask_logits (Tensor): A tensor of shape (B, C, Hmask, Wmask) or (B, 1, Hmask, Wmask)
            for class-specific or class-agnostic, where B is the total number of predicted masks
            in all images, C is the number of foreground classes, and Hmask, Wmask are the height
            and width of the mask predictions. The values are logits.
        pred_instances (list[Instances]): A list of N Instances, where N is the number of images
            in the batch. Each Instances must have field "pred_classes".

    Returns:
        None. pred_instances will contain an extra "pred_masks" field storing, for each
            predicted box, a mask of size (1, Hmask, Wmask) for its predicted class. The masks
            are soft (non-quantized) and at the resolution predicted by the network;
            post-processing steps, such as resizing the masks to the original image
            resolution and/or binarizing them, are left to the caller.
    """
    if pred_mask_logits.size(1) == 1:
        # Class-agnostic prediction: the single channel is already the mask of interest.
        selected_logits = pred_mask_logits
    else:
        # Class-specific prediction: pick, per box, the channel of its predicted class.
        mask_count = pred_mask_logits.shape[0]
        predicted_classes = cat([inst.pred_classes for inst in pred_instances])
        arange_device = (
            predicted_classes.device
            if torch.jit.is_scripting()
            else ("cpu" if torch.jit.is_tracing() else predicted_classes.device)
        )
        row_index = move_device_like(
            torch.arange(mask_count, device=arange_device), predicted_classes
        )
        # Indexing drops the channel axis; [:, None] restores it as a singleton.
        selected_logits = pred_mask_logits[row_index, predicted_classes][:, None]

    # mask_probs.shape: (B, 1, Hmask, Wmask)
    mask_probs = selected_logits.sigmoid()

    # Cut the batch-wide tensor back into one chunk per image, in input order.
    boxes_per_image = [len(inst) for inst in pred_instances]
    per_image_probs = mask_probs.split(boxes_per_image, dim=0)
    for image_probs, image_instances in zip(per_image_probs, pred_instances):
        image_instances.pred_masks = image_probs  # (#boxes in image, 1, Hmask, Wmask)
class BaseMaskRCNNHead(nn.Module):
    """
    Base class carrying the Mask R-CNN loss and inference logic described in
    :paper:`Mask R-CNN`. The network itself is supplied by subclasses through
    :meth:`layers`.
    """

    @configurable
    def __init__(self, *, loss_weight: float = 1.0, vis_period: int = 0):
        """
        NOTE: this interface is experimental.

        Args:
            loss_weight (float): factor the mask loss is multiplied by
            vis_period (int): visualization period, forwarded to the loss computation
        """
        super().__init__()
        self.loss_weight = loss_weight
        self.vis_period = vis_period

    @classmethod
    def from_config(cls, cfg, input_shape):
        # Only the visualization period is taken from the config here.
        return dict(vis_period=cfg.VIS_PERIOD)

    def forward(self, x, instances: List[Instances]):
        """
        Args:
            x: input region feature(s) provided by :class:`ROIHeads`.
            instances (list[Instances]): the boxes & labels that correspond to the
                input features. The exact format is up to the caller.
                Typically these are the foreground instances in training, carrying a
                "proposal_boxes" field and other gt annotations; in inference they
                hold the boxes that were already predicted.

        Returns:
            In training, a dict with the single key "loss_mask". In inference, the
            given "instances", updated in place with the predicted masks.
        """
        mask_logits = self.layers(x)
        # NOTE(review): the two branches return different types, so this is kept as
        # a plain if/else instead of being flattened into an early return.
        if self.training:
            loss = mask_rcnn_loss(mask_logits, instances, self.vis_period)
            return {"loss_mask": loss * self.loss_weight}
        else:
            mask_rcnn_inference(mask_logits, instances)
            return instances

    def layers(self, x):
        """
        Neural network layers that turn input features into mask predictions.
        Must be overridden by subclasses.
        """
        raise NotImplementedError
# To get torchscript support, we make the head a subclass of `nn.Sequential`.
# Therefore, to add new layers in this head class, please make sure they are
# added in the order they will be used in forward().
@ROI_MASK_HEAD_REGISTRY.register()
class MaskRCNNConvUpsampleHead(BaseMaskRCNNHead, nn.Sequential):
    """
    Mask head made of a stack of 3x3 conv layers followed by a 2x upsampling
    `ConvTranspose2d` and a final 1x1 conv layer that produces the predictions.
    """

    @configurable
    def __init__(self, input_shape: ShapeSpec, *, num_classes, conv_dims, conv_norm="", **kwargs):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature
            num_classes (int): the number of foreground classes (i.e. background is not
                included). 1 if using class agnostic prediction.
            conv_dims (list[int]): a list of N>0 integers; the first N-1 are the output
                dimensions of the conv layers and the last one is the output dimension
                of the upsample layer.
            conv_norm (str or callable): normalization for the conv layers.
                See :func:`detectron2.layers.get_norm` for supported types.
        """
        super().__init__(**kwargs)
        assert len(conv_dims) >= 1, "conv_dims have to be non-empty!"

        *hidden_dims, upsample_dim = conv_dims

        # Modules are registered in exactly the order they run, because `layers`
        # simply walks over this container.
        self.conv_norm_relus = []
        in_channels = input_shape.channels
        for position, out_channels in enumerate(hidden_dims, start=1):
            conv_block = Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                # The conv bias is only used when no normalization is configured.
                bias=not conv_norm,
                norm=get_norm(conv_norm, out_channels),
                activation=nn.ReLU(),
            )
            self.add_module(f"mask_fcn{position}", conv_block)
            self.conv_norm_relus.append(conv_block)
            in_channels = out_channels

        self.deconv = ConvTranspose2d(
            in_channels, upsample_dim, kernel_size=2, stride=2, padding=0
        )
        self.add_module("deconv_relu", nn.ReLU())

        self.predictor = Conv2d(upsample_dim, num_classes, kernel_size=1, stride=1, padding=0)

        # MSRA fill for every conv block and for the upsampling layer ...
        for module in [*self.conv_norm_relus, self.deconv]:
            weight_init.c2_msra_fill(module)
        # ... and a narrow normal distribution for the mask prediction layer.
        nn.init.normal_(self.predictor.weight, std=0.001)
        if self.predictor.bias is not None:
            nn.init.constant_(self.predictor.bias, 0)

    @classmethod
    def from_config(cls, cfg, input_shape):
        ret = super().from_config(cfg, input_shape)
        mask_cfg = cfg.MODEL.ROI_MASK_HEAD
        # +1 for ConvTranspose
        ret["conv_dims"] = [mask_cfg.CONV_DIM] * (mask_cfg.NUM_CONV + 1)
        ret["conv_norm"] = mask_cfg.NORM
        ret["input_shape"] = input_shape
        # A class-agnostic head predicts one mask channel shared by all classes.
        ret["num_classes"] = 1 if mask_cfg.CLS_AGNOSTIC_MASK else cfg.MODEL.ROI_HEADS.NUM_CLASSES
        return ret

    def layers(self, x):
        # Apply every registered submodule in registration order.
        for module in self:
            x = module(x)
        return x
def build_mask_head(cfg, input_shape):
    """
    Instantiate the mask head whose registry name is given by
    `cfg.MODEL.ROI_MASK_HEAD.NAME`, calling it with `(cfg, input_shape)`.
    """
    head_factory = ROI_MASK_HEAD_REGISTRY.get(cfg.MODEL.ROI_MASK_HEAD.NAME)
    return head_factory(cfg, input_shape)
|