File size: 5,446 Bytes
3f9659e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from dataclasses import dataclass
from typing import Any, Iterable, List, Optional
import torch
from torch.nn import functional as F

from detectron2.structures import Instances


@dataclass
class DataForMaskLoss:
    """
    Ground truth and estimated mask data for proposals gathered from
    multiple images, bundled together for the mask loss computation.

    Both fields default to `None`.
    """

    # GT labels: tensor of size (K, H, W)
    masks_gt: Optional[torch.Tensor] = None
    # estimated scores: tensor of size (K, C, H, W)
    masks_est: Optional[torch.Tensor] = None


def extract_data_for_mask_loss_from_matches(
    proposals_targets: Iterable[Instances], estimated_segm: torch.Tensor
) -> DataForMaskLoss:
    """
    Extract data for mask loss from instances that contain matched GT and
    estimated bounding boxes.

    Args:
        proposals_targets: Iterable[Instances]
            matched GT and estimated results, each item in the iterable
            corresponds to data in 1 image; single-pass iterables
            (e.g. generators) are supported
        estimated_segm: tensor(K, C, S, S) of float - raw unnormalized
            segmentation scores, here S is the size to which GT masks are
            to be resized
    Return:
        DataForMaskLoss with the following fields set when at least one image
        has proposals (both stay `None` otherwise):
            masks_est: `estimated_segm` itself, tensor(K, C, S, S)
            masks_gt: GT masks cropped to the proposal boxes and resized to
                S x S by `gt_masks.crop_and_resize`, concatenated over images
                along dim 0 and moved to the device of `estimated_segm`
    Raises:
        AssertionError: if the last two dimensions of `estimated_segm` differ,
            or if the total number of proposals is not equal to
            `estimated_segm.shape[0]`
    """
    # The input is traversed twice (once to count proposals, once to crop the
    # masks). Materialize it so that a single-pass iterable is not exhausted
    # by the counting pass, which would silently yield empty loss data.
    proposals_targets = list(proposals_targets)

    assert estimated_segm.shape[2] == estimated_segm.shape[3], (
        f"Expected estimated segmentation to have a square shape, "
        f"but the actual shape is {estimated_segm.shape[2:]}"
    )
    mask_size = estimated_segm.shape[2]
    num_proposals = sum(inst.proposal_boxes.tensor.size(0) for inst in proposals_targets)
    num_estimated = estimated_segm.shape[0]
    assert (
        num_proposals == num_estimated
    ), "The number of proposals {} must be equal to the number of estimates {}".format(
        num_proposals, num_estimated
    )

    masks_gt = []
    for proposals_targets_per_image in proposals_targets:
        boxes = proposals_targets_per_image.proposal_boxes.tensor
        # images without proposals contribute nothing to the loss data
        if not boxes.size(0):
            continue
        gt_masks_per_image = proposals_targets_per_image.gt_masks.crop_and_resize(
            boxes, mask_size
        ).to(device=estimated_segm.device)
        masks_gt.append(gt_masks_per_image)

    data = DataForMaskLoss()
    if masks_gt:
        data.masks_est = estimated_segm
        data.masks_gt = torch.cat(masks_gt, dim=0)
    return data


class MaskLoss:
    """
    Cross-entropy mask loss computed on raw unnormalized scores against
    ground truth labels.

    Ground truth mask labels are defined for the whole image rather than just
    the bounding box of interest. They are stored as objects that are assumed
    to implement the `crop_and_resize` interface (e.g. BitMasks, PolygonMasks).
    """

    def __call__(
        self, proposals_with_gt: List[Instances], densepose_predictor_outputs: Any
    ) -> torch.Tensor:
        """
        Compute the segmentation loss as cross-entropy between raw
        unnormalized scores and ground truth labels.

        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data
            densepose_predictor_outputs: an object of a dataclass that contains
                predictor outputs with estimated values; assumed to have the
                following attribute:
                * coarse_segm (tensor of shape [N, D, S, S]): coarse
                    segmentation estimates as raw unnormalized scores
                where N is the number of detections, S is the estimate size
                ( = width = height) and D is the number of coarse segmentation
                channels.
        Return:
            Cross entropy for raw unnormalized scores for coarse segmentation
            given ground truth labels from masks; the zero-valued fake loss
            (see `fake_value`) if there are no proposals or no mask data
            could be extracted
        """
        if len(proposals_with_gt) == 0:
            return self.fake_value(densepose_predictor_outputs)
        # Predictor outputs cover every bounding box of every image in the
        # batch: for 4 images with (3, 1, 2, 1) proposals respectively,
        # size(0) of the outputs is 3+1+2+1 == 7.
        with torch.no_grad():
            loss_inputs = extract_data_for_mask_loss_from_matches(
                proposals_with_gt, densepose_predictor_outputs.coarse_segm
            )
        labels = loss_inputs.masks_gt
        scores = loss_inputs.masks_est
        if labels is None or scores is None:
            return self.fake_value(densepose_predictor_outputs)
        return F.cross_entropy(scores, labels.long())

    def fake_value(self, densepose_predictor_outputs: Any) -> torch.Tensor:
        """
        Fake segmentation loss for batches in which no suitable ground truth
        data was found. Its value is 0; it exists primarily to construct the
        computation graph, so that `DistributedDataParallel` has similar
        graphs on all GPUs and can perform reduction properly.

        Args:
            densepose_predictor_outputs: DensePose predictor outputs, an object
                of a dataclass that is assumed to have `coarse_segm`
                attribute
        Return:
            Zero value loss with proper computation graph
        """
        coarse_segm = densepose_predictor_outputs.coarse_segm
        # multiplying the sum by 0 keeps the result attached to `coarse_segm`
        zero_loss = coarse_segm.sum() * 0
        return zero_loss