|
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any, Iterable, List, Optional
|
|
import torch
|
|
from torch.nn import functional as F
|
|
|
|
from detectron2.structures import Instances
|
|
|
|
|
|
@dataclass
class DataForMaskLoss:
    """
    Container for the mask loss inputs gathered over the proposals of
    several images.

    Both fields default to None; `extract_data_for_mask_loss_from_matches`
    only fills them in when at least one proposal contributed a GT mask.
    """

    # ground truth masks of all proposals, concatenated along dim 0
    masks_gt: Optional[torch.Tensor] = None
    # raw unnormalized segmentation scores for the same proposals
    masks_est: Optional[torch.Tensor] = None
|
|
|
|
|
|
def extract_data_for_mask_loss_from_matches(
    proposals_targets: Iterable[Instances], estimated_segm: torch.Tensor
) -> DataForMaskLoss:
    """
    Extract data for mask loss from instances that contain matched GT and
    estimated bounding boxes.

    Args:
        proposals_targets: Iterable[Instances]
            matched GT and estimated results, each item in the iterable
            corresponds to data in 1 image. Each item must provide
            `proposal_boxes.tensor` (one row per proposal) and `gt_masks`
            implementing `crop_and_resize(boxes, mask_size)`. Any iterable is
            accepted, including one-shot iterators such as generators: the
            input is materialized once because it is traversed twice.
        estimated_segm: tensor(K, C, S, S) of float - raw unnormalized
            segmentation scores, here S is the size to which GT masks are
            to be resized and K is the total number of proposals
    Return:
        DataForMaskLoss with
            masks_est: tensor(K, C, S, S) of float - class scores
                (the `estimated_segm` tensor itself)
            masks_gt: tensor(K, S, S) - GT masks exactly as returned by
                `crop_and_resize`, moved to `estimated_segm.device`; convert
                to int64 (e.g. `.long()`) before using them as class labels
        Both fields are left as None when no image has any proposal.
    Raises:
        AssertionError: if the last two dimensions of `estimated_segm`
            differ, or if the total number of proposals is not K
    """
    # Materialize once: the proposal count and the mask loop below both
    # iterate over the input; a generator would be exhausted by the count
    # and the loop would then silently produce no GT masks at all.
    proposals_targets = list(proposals_targets)
    data = DataForMaskLoss()
    assert estimated_segm.shape[2] == estimated_segm.shape[3], (
        f"Expected estimated segmentation to have a square shape, "
        f"but the actual shape is {estimated_segm.shape[2:]}"
    )
    mask_size = estimated_segm.shape[2]
    num_proposals = sum(inst.proposal_boxes.tensor.size(0) for inst in proposals_targets)
    num_estimated = estimated_segm.shape[0]
    assert (
        num_proposals == num_estimated
    ), "The number of proposals {} must be equal to the number of estimates {}".format(
        num_proposals, num_estimated
    )

    masks_gt = []
    for proposals_targets_per_image in proposals_targets:
        proposal_boxes = proposals_targets_per_image.proposal_boxes.tensor
        # images without proposals contribute no rows
        if not proposal_boxes.size(0):
            continue
        # GT masks cover the whole image: cut out and resize each proposal's box
        gt_masks_per_image = proposals_targets_per_image.gt_masks.crop_and_resize(
            proposal_boxes, mask_size
        ).to(device=estimated_segm.device)
        masks_gt.append(gt_masks_per_image)
    if masks_gt:
        data.masks_est = estimated_segm
        data.masks_gt = torch.cat(masks_gt, dim=0)
    return data
|
|
|
|
|
|
class MaskLoss:
    """
    Cross-entropy mask loss on raw unnormalized segmentation scores.

    Ground truth masks are defined over the whole image rather than only the
    bounding box of interest. They are held in objects assumed to implement
    the `crop_and_resize` interface (e.g. BitMasks, PolygonMasks), which is
    used to obtain per-proposal GT labels.
    """

    def __call__(
        self, proposals_with_gt: List[Instances], densepose_predictor_outputs: Any
    ) -> torch.Tensor:
        """
        Compute the segmentation loss as cross-entropy between raw unnormalized
        coarse segmentation scores and ground truth labels taken from masks.

        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data
            densepose_predictor_outputs: an object of a dataclass holding
                predictor outputs; assumed to expose the attribute
                * coarse_segm (tensor of shape [N, D, S, S]): coarse
                    segmentation estimates as raw unnormalized scores,
                where N is the number of detections, S is the estimate size
                (= width = height) and D is the number of coarse segmentation
                channels.
        Return:
            Cross entropy of the coarse segmentation scores w.r.t. the GT
            labels; a zero-valued loss attached to `coarse_segm` (see
            `fake_value`) when there are no detections or no GT masks.
        """
        if len(proposals_with_gt) == 0:
            return self.fake_value(densepose_predictor_outputs)

        # Only the GT extraction runs without autograd tracking; the estimates
        # are passed through by reference and the loss below is computed
        # outside this context, so gradients still reach `coarse_segm`.
        with torch.no_grad():
            loss_data = extract_data_for_mask_loss_from_matches(
                proposals_with_gt, densepose_predictor_outputs.coarse_segm
            )

        has_data = loss_data.masks_gt is not None and loss_data.masks_est is not None
        if not has_data:
            return self.fake_value(densepose_predictor_outputs)

        # cross_entropy expects integer class indices as targets
        labels = loss_data.masks_gt.long()
        return F.cross_entropy(loss_data.masks_est, labels)

    def fake_value(self, densepose_predictor_outputs: Any) -> torch.Tensor:
        """
        Zero-valued segmentation loss for batches without usable ground truth.

        Multiplying the sum of the estimates by zero keeps `coarse_segm` in
        the computation graph, so `DistributedDataParallel` sees similar
        graphs on all GPUs and can perform reduction properly.

        Args:
            densepose_predictor_outputs: DensePose predictor outputs, an object
                of a dataclass that is assumed to have a `coarse_segm`
                attribute
        Return:
            Zero value loss with proper computation graph
        """
        coarse_segm = densepose_predictor_outputs.coarse_segm
        return coarse_segm.sum() * 0
|
|
|