File size: 3,042 Bytes
3f9659e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, List
import torch

from detectron2.config import CfgNode
from detectron2.structures import Instances

from .mask import MaskLoss
from .segm import SegmentationLoss


class MaskOrSegmentationLoss:
    """
    Cross-entropy loss over raw (unnormalized) coarse segmentation scores.

    The ground truth labels come either from mask annotations or from coarse
    segmentation annotations. Which one is used is controlled by the config
    value MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS.
    """

    def __init__(self, cfg: CfgNode):
        """
        Set up the loss according to the configuration options.

        Args:
            cfg (CfgNode): configuration options; the flag
                MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS is read from it
        """
        densepose_head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
        self.segm_trained_by_masks = densepose_head_cfg.COARSE_SEGM_TRAINED_BY_MASKS
        # The mask-based loss is only instantiated when the flag asks for it;
        # the segmentation-based loss is always created.
        if self.segm_trained_by_masks:
            self.mask_loss = MaskLoss()
        self.segm_loss = SegmentationLoss(cfg)

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: Any,
    ) -> torch.Tensor:
        """
        Evaluate the segmentation loss as cross-entropy between aligned
        unnormalized score estimates and the ground truth, where the ground
        truth is given either by masks or by coarse segmentation annotations.

        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data
            densepose_predictor_outputs: an object of a dataclass that contains
                predictor outputs with estimated values; assumed to have the
                following attributes:
                * coarse_segm - coarse segmentation estimates,
                  tensor of shape [N, D, S, S]
            packed_annotations: packed annotations for efficient loss
                computation; only forwarded to the segmentation-based loss
        Return:
            tensor: loss value as cross-entropy for raw unnormalized scores
                given ground truth labels
        """
        if not self.segm_trained_by_masks:
            # Ground truth defined by coarse segmentation annotations.
            return self.segm_loss(
                proposals_with_gt, densepose_predictor_outputs, packed_annotations
            )
        # Ground truth defined by mask annotations; packed annotations are not passed.
        return self.mask_loss(proposals_with_gt, densepose_predictor_outputs)

    def fake_value(self, densepose_predictor_outputs: Any) -> torch.Tensor:
        """
        Produce a placeholder segmentation loss for batches in which no
        suitable ground truth data was found. Its value is 0; it exists mainly
        to build the computation graph, so that `DistributedDataParallel` sees
        similar graphs on all GPUs and can perform reduction properly.

        Args:
            densepose_predictor_outputs: DensePose predictor outputs, an object
                of a dataclass that is assumed to have a `coarse_segm` attribute
        Return:
            Zero value loss with proper computation graph
        """
        # Summing first keeps the result attached to `coarse_segm`;
        # the multiplication then forces the value to zero.
        coarse_segm_total = densepose_predictor_outputs.coarse_segm.sum()
        return coarse_segm_total * 0