Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
# pyre-unsafe | |
from typing import Any, List | |
import torch | |
from torch.nn import functional as F | |
from detectron2.config import CfgNode | |
from detectron2.structures import Instances | |
from .utils import resample_data | |
class SegmentationLoss:
    """
    Segmentation loss as cross-entropy for raw unnormalized scores given ground truth
    labels. Segmentation ground truth labels are defined for the bounding box of
    interest at some fixed resolution [S, S], where
        S = MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE.
    """

    def __init__(self, cfg: CfgNode):
        """
        Initialize segmentation loss from configuration options.

        Args:
            cfg (CfgNode): configuration options; reads
                MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE and
                MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS
        """
        # Side length S of the square [S, S] grid the ground truth is resampled to.
        self.heatmap_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE
        # Number of coarse segmentation channels; when it equals 2 the ground
        # truth labels are binarized before computing the loss (see __call__).
        self.n_segm_chan = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: Any,
    ) -> torch.Tensor:
        """
        Compute segmentation loss as cross-entropy on aligned segmentation
        ground truth and estimated scores.

        Args:
            proposals_with_gt (list of Instances): detections with associated ground
                truth data (not read by this loss)
            densepose_predictor_outputs: an object of a dataclass that contains predictor
                outputs with estimated values; assumed to have the following attributes:
                * coarse_segm - coarse segmentation estimates, tensor of shape
                  [N, D, S, S] (presumably D == NUM_COARSE_SEGM_CHANNELS)
            packed_annotations: packed annotations for efficient loss computation;
                the following attributes are used:
                - coarse_segm_gt
                - bbox_xywh_gt
                - bbox_xywh_est
                - bbox_indices
        Return:
            Scalar cross-entropy loss (default mean reduction over all elements of
            the selected boxes). If no segmentation ground truth is available, or
            it is empty, the zero-valued loss from `fake_value` is returned instead.
        """
        segm_gt_packed = packed_annotations.coarse_segm_gt
        # An empty ground truth would make the mean-reduced cross-entropy below
        # average over zero elements and evaluate to NaN; use the zero-valued
        # fake loss so the computation graph is still built.
        if segm_gt_packed is None or segm_gt_packed.numel() == 0:
            return self.fake_value(densepose_predictor_outputs)

        # Keep only the estimates of boxes that have ground truth annotations.
        coarse_segm_est = densepose_predictor_outputs.coarse_segm[packed_annotations.bbox_indices]

        # Ground truth preparation needs no gradients.
        with torch.no_grad():
            # Presumably maps the GT labels from the GT box frame into the estimated
            # box frame on an [S, S] grid; "nearest" keeps the values integral
            # labels, "zeros" fills areas outside the GT box with label 0.
            coarse_segm_gt = resample_data(
                segm_gt_packed.unsqueeze(1),
                packed_annotations.bbox_xywh_gt,
                packed_annotations.bbox_xywh_est,
                self.heatmap_size,
                self.heatmap_size,
                mode="nearest",
                padding_mode="zeros",
            ).squeeze(1)

        if self.n_segm_chan == 2:
            # Two-channel estimates: collapse every positive label into class 1.
            coarse_segm_gt = coarse_segm_gt > 0

        return F.cross_entropy(coarse_segm_est, coarse_segm_gt.long())

    def fake_value(self, densepose_predictor_outputs: Any) -> torch.Tensor:
        """
        Fake segmentation loss used when no suitable ground truth data
        was found in a batch. The loss has a value 0 and is primarily used to
        construct the computation graph, so that `DistributedDataParallel`
        has similar graphs on all GPUs and can perform reduction properly.

        Args:
            densepose_predictor_outputs: DensePose predictor outputs, an object
                of a dataclass that is assumed to have `coarse_segm`
                attribute
        Return:
            Zero value loss with proper computation graph
        """
        # Multiplying by 0 keeps coarse_segm in the autograd graph while
        # contributing nothing to the loss value.
        return densepose_predictor_outputs.coarse_segm.sum() * 0