# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.layers import ShapeSpec, cat
from detectron2.utils.events import get_event_storage
from detectron2.utils.registry import Registry

POINT_HEAD_REGISTRY = Registry("POINT_HEAD")
POINT_HEAD_REGISTRY.__doc__ = """
Registry for point heads, which make predictions for a given set of per-point features.

The registered object will be called with `obj(cfg, input_shape)`.
"""


def roi_mask_point_loss(mask_logits, instances, point_labels):
    """
    Compute the point-based loss for instance segmentation mask predictions
    given point-wise mask prediction and its corresponding point-wise labels.

    Args:
        mask_logits (Tensor): A tensor of shape (R, C, P) or (R, 1, P) for class-specific or
            class-agnostic, where R is the total number of predicted masks in all images, C is the
            number of foreground classes, and P is the number of points sampled for each mask.
            The values are logits.
        instances (list[Instances]): A list of N Instances, where N is the number of images
            in the batch. These instances are in 1:1 correspondence with the `mask_logits`. So,
            the i-th element of the list contains R_i objects and R_1 + ... + R_N is equal to R.
            The ground-truth labels (class, box, mask, ...) associated with each instance are stored
            in fields.
        point_labels (Tensor): A tensor of shape (R, P), where R is the total number of
            predicted masks and P is the number of points for each mask.
            Labels with value of -1 will be ignored.

    Returns:
        point_loss (Tensor): A scalar tensor containing the loss.
    """
    with torch.no_grad():
        cls_agnostic_mask = mask_logits.size(1) == 1
        total_num_masks = mask_logits.size(0)

        gt_classes = []
        for instances_per_image in instances:
            if len(instances_per_image) == 0:
                continue
            if not cls_agnostic_mask:
                gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64)
                gt_classes.append(gt_classes_per_image)

    gt_mask_logits = point_labels
    point_ignores = point_labels == -1
    if gt_mask_logits.shape[0] == 0:
        return mask_logits.sum() * 0

    assert gt_mask_logits.numel() > 0, gt_mask_logits.shape

    if cls_agnostic_mask:
        mask_logits = mask_logits[:, 0]
    else:
        indices = torch.arange(total_num_masks)
        gt_classes = cat(gt_classes, dim=0)
        mask_logits = mask_logits[indices, gt_classes]

    # Log the training accuracy (using gt classes and 0.0 threshold for the logits)
    mask_accurate = (mask_logits > 0.0) == gt_mask_logits.to(dtype=torch.uint8)
    mask_accurate = mask_accurate[~point_ignores]
    mask_accuracy = mask_accurate.nonzero().size(0) / max(mask_accurate.numel(), 1.0)
    get_event_storage().put_scalar("point/accuracy", mask_accuracy)

    point_loss = F.binary_cross_entropy_with_logits(
        mask_logits, gt_mask_logits.to(dtype=torch.float32), weight=~point_ignores, reduction="mean"
    )
    return point_loss
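
# Illustrative call with made-up shapes (R=4 masks, C=80 classes, P=196 points);
# assumes an active EventStorage (the detectron2 trainer provides one) and that
# `instances` is a list of Instances whose gt_classes concatenate to length R:
#
#   from detectron2.utils.events import EventStorage
#   mask_logits = torch.randn(4, 80, 196)                  # (R, C, P) predicted logits
#   point_labels = torch.randint(-1, 2, (4, 196)).float()  # per-point labels in {-1, 0, 1}
#   with EventStorage():
#       loss = roi_mask_point_loss(mask_logits, instances, point_labels)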


@POINT_HEAD_REGISTRY.register()
class StandardPointHead(nn.Module):
    """
    A point head multi-layer perceptron which we model with conv1d layers with kernel 1. The head
    takes both fine-grained and coarse prediction features as its input.
    """

    def __init__(self, cfg, input_shape: ShapeSpec):
        """
        The following attributes are parsed from config:
            fc_dim: the output dimension of each FC layer
            num_fc: the number of FC layers
            coarse_pred_each_layer: if True, coarse prediction features are concatenated to each
                layer's input
        """
        super(StandardPointHead, self).__init__()
        # fmt: off
        num_classes = cfg.MODEL.POINT_HEAD.NUM_CLASSES
        fc_dim = cfg.MODEL.POINT_HEAD.FC_DIM
        num_fc = cfg.MODEL.POINT_HEAD.NUM_FC
        cls_agnostic_mask = cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK
        self.coarse_pred_each_layer = cfg.MODEL.POINT_HEAD.COARSE_PRED_EACH_LAYER
        input_channels = input_shape.channels
        # fmt: on

        fc_dim_in = input_channels + num_classes
        self.fc_layers = []
        for k in range(num_fc):
            fc = nn.Conv1d(fc_dim_in, fc_dim, kernel_size=1, stride=1, padding=0, bias=True)
            self.add_module("fc{}".format(k + 1), fc)
            self.fc_layers.append(fc)
            fc_dim_in = fc_dim
            fc_dim_in += num_classes if self.coarse_pred_each_layer else 0

        num_mask_classes = 1 if cls_agnostic_mask else num_classes
        self.predictor = nn.Conv1d(fc_dim_in, num_mask_classes, kernel_size=1, stride=1, padding=0)

        for layer in self.fc_layers:
            weight_init.c2_msra_fill(layer)
        # use normal distribution initialization for mask prediction layer
        nn.init.normal_(self.predictor.weight, std=0.001)
        if self.predictor.bias is not None:
            nn.init.constant_(self.predictor.bias, 0)

    def forward(self, fine_grained_features, coarse_features):
        x = torch.cat((fine_grained_features, coarse_features), dim=1)
        for layer in self.fc_layers:
            x = F.relu(layer(x))
            if self.coarse_pred_each_layer:
                x = cat((x, coarse_features), dim=1)
        return self.predictor(x)
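
    # Shape sketch (hypothetical sizes; `cfg` is a detectron2 config with the
    # PointRend defaults FC_DIM=256, NUM_FC=3): fine-grained features are
    # (R, input_channels, P) point samples from a backbone map, coarse features
    # are (R, num_classes, P) point-sampled coarse mask logits, and the output
    # is (R, num_mask_classes, P) refined point logits:
    #
    #   head = StandardPointHead(cfg, ShapeSpec(channels=256))
    #   point_logits = head(torch.randn(8, 256, 196), torch.randn(8, 80, 196))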


@POINT_HEAD_REGISTRY.register()
class ImplicitPointHead(nn.Module):
    """
    A point head multi-layer perceptron which we model with conv1d layers with kernel 1. The head
    takes both fine-grained features and instance-wise MLP parameters as its input.
    """

    def __init__(self, cfg, input_shape: ShapeSpec):
        """
        The following attributes are parsed from config:
            channels: the output dimension of each FC layer
            num_layers: the number of FC layers (including the final prediction layer)
            image_feature_enabled: if True, fine-grained image-level features are used
            positional_encoding_enabled: if True, positional encoding is used
        """
        super(ImplicitPointHead, self).__init__()
        # fmt: off
        self.num_layers = cfg.MODEL.POINT_HEAD.NUM_FC + 1
        self.channels = cfg.MODEL.POINT_HEAD.FC_DIM
        self.image_feature_enabled = cfg.MODEL.IMPLICIT_POINTREND.IMAGE_FEATURE_ENABLED
        self.positional_encoding_enabled = cfg.MODEL.IMPLICIT_POINTREND.POS_ENC_ENABLED
        self.num_classes = (
            cfg.MODEL.POINT_HEAD.NUM_CLASSES if not cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK else 1
        )
        self.in_channels = input_shape.channels
        # fmt: on

        if not self.image_feature_enabled:
            self.in_channels = 0
        if self.positional_encoding_enabled:
            self.in_channels += 256
            self.register_buffer("positional_encoding_gaussian_matrix", torch.randn((2, 128)))

        assert self.in_channels > 0

        num_weight_params, num_bias_params = [], []
        assert self.num_layers >= 2
        for l in range(self.num_layers):
            if l == 0:
                # input layer
                num_weight_params.append(self.in_channels * self.channels)
                num_bias_params.append(self.channels)
            elif l == self.num_layers - 1:
                # output layer
                num_weight_params.append(self.channels * self.num_classes)
                num_bias_params.append(self.num_classes)
            else:
                # intermediate layer
                num_weight_params.append(self.channels * self.channels)
                num_bias_params.append(self.channels)

        self.num_weight_params = num_weight_params
        self.num_bias_params = num_bias_params
        self.num_params = sum(num_weight_params) + sum(num_bias_params)
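
    # Worked example of the parameter count (hypothetical values): with
    # in_channels=256, channels=256, num_classes=1, and num_layers=3, the
    # per-instance MLP needs
    #   weights: 256*256 + 256*256 + 256*1 = 131328
    #   biases:  256 + 256 + 1 = 513
    # so self.num_params = 131841 scalars must be predicted per instance.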

    def forward(self, fine_grained_features, point_coords, parameters):
        # features: [R, channels, K]
        # point_coords: [R, K, 2]
        num_instances = fine_grained_features.size(0)
        num_points = fine_grained_features.size(2)

        if num_instances == 0:
            return torch.zeros((0, 1, num_points), device=fine_grained_features.device)

        if self.positional_encoding_enabled:
            # locations: [R*K, 2]
            locations = 2 * point_coords.reshape(num_instances * num_points, 2) - 1
            locations = locations @ self.positional_encoding_gaussian_matrix.to(locations.device)
            locations = 2 * np.pi * locations
            locations = torch.cat([torch.sin(locations), torch.cos(locations)], dim=1)
            # locations: [R, C, K]
            locations = locations.reshape(num_instances, num_points, 256).permute(0, 2, 1)
            if not self.image_feature_enabled:
                fine_grained_features = locations
            else:
                fine_grained_features = torch.cat([locations, fine_grained_features], dim=1)

        # features [R, C, K]
        mask_feat = fine_grained_features.reshape(num_instances, self.in_channels, num_points)

        weights, biases = self._parse_params(
            parameters,
            self.in_channels,
            self.channels,
            self.num_classes,
            self.num_weight_params,
            self.num_bias_params,
        )

        point_logits = self._dynamic_mlp(mask_feat, weights, biases, num_instances)
        point_logits = point_logits.reshape(-1, self.num_classes, num_points)

        return point_logits
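
    # Illustrative forward pass with made-up sizes (`head` is an ImplicitPointHead;
    # assumes both image features and positional encoding are enabled, so
    # head.in_channels = backbone channels + 256):
    #
    #   R, K = 2, 196
    #   feats = torch.randn(R, head.in_channels - 256, K)  # fine-grained point features
    #   coords = torch.rand(R, K, 2)                       # normalized (x, y) in [0, 1]
    #   params = torch.randn(R, head.num_params)           # per-instance MLP parameters
    #   logits = head(feats, coords, params)               # -> (R, num_classes, K)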

    @staticmethod
    def _dynamic_mlp(features, weights, biases, num_instances):
        assert features.dim() == 3, features.dim()
        n_layers = len(weights)
        x = features
        for i, (w, b) in enumerate(zip(weights, biases)):
            # x: [R, c_in, K], w: [R, c_out, c_in] -> x: [R, c_out, K]
            x = torch.einsum("nck,ndc->ndk", x, w) + b
            if i < n_layers - 1:
                x = F.relu(x)
        return x

    @staticmethod
    def _parse_params(
        pred_params,
        in_channels,
        channels,
        num_classes,
        num_weight_params,
        num_bias_params,
    ):
        assert pred_params.dim() == 2
        assert len(num_weight_params) == len(num_bias_params)
        assert pred_params.size(1) == sum(num_weight_params) + sum(num_bias_params)

        num_instances = pred_params.size(0)
        num_layers = len(num_weight_params)

        params_splits = list(
            torch.split_with_sizes(pred_params, num_weight_params + num_bias_params, dim=1)
        )

        weight_splits = params_splits[:num_layers]
        bias_splits = params_splits[num_layers:]

        for l in range(num_layers):
            if l == 0:
                # input layer
                weight_splits[l] = weight_splits[l].reshape(num_instances, channels, in_channels)
                bias_splits[l] = bias_splits[l].reshape(num_instances, channels, 1)
            elif l < num_layers - 1:
                # intermediate layer
                weight_splits[l] = weight_splits[l].reshape(num_instances, channels, channels)
                bias_splits[l] = bias_splits[l].reshape(num_instances, channels, 1)
            else:
                # output layer
                weight_splits[l] = weight_splits[l].reshape(num_instances, num_classes, channels)
                bias_splits[l] = bias_splits[l].reshape(num_instances, num_classes, 1)

        return weight_splits, bias_splits
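
    # Splitting sketch (hypothetical sizes): for in_channels=64, channels=8,
    # num_classes=1, and two layers, num_weight_params = [512, 8] and
    # num_bias_params = [8, 1], so a (R, 529) parameter tensor is split along
    # dim=1 into chunks of sizes [512, 8, 8, 1] and reshaped into weights of
    # shape (R, 8, 64) and (R, 1, 8) plus biases of shape (R, 8, 1) and (R, 1, 1).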


def build_point_head(cfg, input_channels):
    """
    Build a point head defined by `cfg.MODEL.POINT_HEAD.NAME`.
    """
    head_name = cfg.MODEL.POINT_HEAD.NAME
    return POINT_HEAD_REGISTRY.get(head_name)(cfg, input_channels)
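
# Typical call site sketch (mirrors how the PointRend ROI heads construct the head;
# `cfg` and `in_channels` are assumed from the surrounding context):
#
#   point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))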