File size: 6,446 Bytes
f549064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

import numpy as np
import torch

from mmcls.registry import BATCH_AUGMENTS
from .mixup import Mixup


@BATCH_AUGMENTS.register_module()
class CutMix(Mixup):
    r"""CutMix batch augmentation.

    CutMix is a method to improve the network's generalization capability. It's
    proposed in `CutMix: Regularization Strategy to Train Strong Classifiers
    with Localizable Features <https://arxiv.org/abs/1905.04899>`_

    With this method, patches are cut and pasted among training images where
    the ground truth labels are also mixed proportionally to the area of the
    patches.

    Args:
        alpha (float): Parameters for Beta distribution to generate the
            mixing ratio. It should be a positive number. More details
            can be found in :class:`Mixup`.
        cutmix_minmax (List[float], optional): The ``(min, max)`` ratio of
            the patch height/width to the image height/width. If not None,
            the height and width of the bounding-box are sampled
            independently within this ratio range, and the mixing ratio
            sampled from ``alpha`` is replaced by the one derived from the
            sampled bounding-box. Otherwise, the bounding-box is generated
            according to the ``alpha``. Defaults to None.
        correct_lam (bool): Whether to apply lambda correction when cutmix bbox
            clipped by image borders. The correction is always applied when
            ``cutmix_minmax`` is set. Defaults to True.

    .. note ::
        If the ``cutmix_minmax`` is None, how to generate the bounding-box of
        patches according to the ``alpha``?

        First, generate a :math:`\lambda`, details can be found in
        :class:`Mixup`. And then, the ratio of each side of the bounding-box
        to the corresponding side of the image is calculated by:

        .. math::
            \text{ratio} = \sqrt{1-\lambda}

        so that, before clipping to the image borders, the bounding-box
        covers about :math:`1-\lambda` of the image area.
    """

    def __init__(self,
                 alpha: float,
                 cutmix_minmax: Optional[List[float]] = None,
                 correct_lam: bool = True):
        super().__init__(alpha=alpha)

        self.cutmix_minmax = cutmix_minmax
        self.correct_lam = correct_lam

    def rand_bbox_minmax(
            self,
            img_shape: Tuple[int, int],
            count: Optional[int] = None) -> Tuple[int, int, int, int]:
        """Min-Max CutMix bounding-box Inspired by Darknet cutmix
        implementation. It generates a random rectangular bbox based on min/max
        percent values applied to each dimension of the input image.

        Typical defaults for minmax are usually in the  .2-.3 for min and
        .8-.9 range for max.

        Args:
            img_shape (tuple): Image shape as tuple ``(height, width)``.
            count (int, optional): Number of bbox to generate. Defaults to None

        Returns:
            tuple: The bbox as ``(yl, yu, xl, xu)``. Each item is an integer
            scalar if ``count`` is None, otherwise an array with ``count``
            elements.

        Raises:
            AssertionError: If ``cutmix_minmax`` does not hold exactly two
                values.
        """
        if len(self.cutmix_minmax) != 2:
            # Raised explicitly instead of through an ``assert`` statement so
            # the check is not stripped when Python runs with ``-O``. The
            # exception type is kept as ``AssertionError`` so that callers
            # observe the same error as before.
            raise AssertionError(
                '`cutmix_minmax` should contain exactly two values '
                f'(min, max), but got {self.cutmix_minmax}.')
        img_h, img_w = img_shape
        # The upper bound of ``np.random.randint`` is exclusive, so the
        # sampled size is always strictly smaller than ``size * max``.
        cut_h = np.random.randint(
            int(img_h * self.cutmix_minmax[0]),
            int(img_h * self.cutmix_minmax[1]),
            size=count)
        cut_w = np.random.randint(
            int(img_w * self.cutmix_minmax[0]),
            int(img_w * self.cutmix_minmax[1]),
            size=count)
        # Sample the top-left corner so the whole bbox stays inside the image.
        yl = np.random.randint(0, img_h - cut_h, size=count)
        xl = np.random.randint(0, img_w - cut_w, size=count)
        yu = yl + cut_h
        xu = xl + cut_w
        return yl, yu, xl, xu

    def rand_bbox(self,
                  img_shape: Tuple[int, int],
                  lam: float,
                  margin: float = 0.,
                  count: Optional[int] = None) -> Tuple[int, int, int, int]:
        """Standard CutMix bounding-box that generates a random bbox based on
        lambda value. Before clipping, the bbox has the same aspect ratio as
        the image (up to integer rounding). This implementation includes
        support for enforcing a border margin as percent of bbox dimensions.

        Args:
            img_shape (tuple): Image shape as tuple ``(height, width)``.
            lam (float): Cutmix lambda value
            margin (float): Percentage of bbox dimension to enforce as margin
                (reduce amount of box outside image). Defaults to 0.
            count (int, optional): Number of bbox to generate. Defaults to None

        Returns:
            tuple: The bbox as ``(yl, yh, xl, xh)``, clipped to the image
            borders. Each item is an integer scalar if ``count`` is None,
            otherwise an array with ``count`` elements.
        """
        # Side ratio; the unclipped bbox covers about ``1 - lam`` of the image.
        ratio = np.sqrt(1 - lam)
        img_h, img_w = img_shape
        cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
        margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
        # Sample the bbox center, keeping it ``margin`` away from the borders.
        cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
        cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
        # The bbox may extend past the image, so clip it to the image borders.
        yl = np.clip(cy - cut_h // 2, 0, img_h)
        yh = np.clip(cy + cut_h // 2, 0, img_h)
        xl = np.clip(cx - cut_w // 2, 0, img_w)
        xh = np.clip(cx + cut_w // 2, 0, img_w)
        return yl, yh, xl, xh

    def cutmix_bbox_and_lam(self,
                            img_shape: Tuple[int, int],
                            lam: float,
                            count: Optional[int] = None) -> tuple:
        """Generate bbox and apply lambda correction.

        Args:
            img_shape (tuple): Image shape as tuple ``(height, width)``.
            lam (float): Cutmix lambda value
            count (int, optional): Number of bbox to generate. Defaults to None

        Returns:
            tuple: ``((yl, yu, xl, xu), lam)``, the generated bbox and the
            lambda value. If ``correct_lam`` is True or ``cutmix_minmax`` is
            set, ``lam`` is recomputed as one minus the ratio of the bbox area
            to the image area; otherwise the input ``lam`` is returned as is.
        """
        if self.cutmix_minmax is not None:
            yl, yu, xl, xu = self.rand_bbox_minmax(img_shape, count=count)
        else:
            yl, yu, xl, xu = self.rand_bbox(img_shape, lam, count=count)
        # With ``cutmix_minmax`` the bbox does not depend on ``lam`` at all,
        # so ``lam`` must always be derived from the actual bbox area.
        if self.correct_lam or self.cutmix_minmax is not None:
            bbox_area = (yu - yl) * (xu - xl)
            lam = 1. - bbox_area / float(img_shape[0] * img_shape[1])
        return (yl, yu, xl, xu), lam

    def mix(self, batch_inputs: torch.Tensor,
            batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Mix the batch inputs and batch one-hot format ground truth.

        The same bbox and the same lambda are used for every sample in the
        batch. Note that ``batch_inputs`` is modified in place.

        Args:
            batch_inputs (Tensor): A batch of images tensor in the shape of
                ``(N, C, H, W)``.
            batch_scores (Tensor): A batch of one-hot format labels in the
                shape of ``(N, num_classes)``.

        Returns:
            Tuple[Tensor, Tensor]: The mixed inputs and labels.
        """
        lam = np.random.beta(self.alpha, self.alpha)
        batch_size = batch_inputs.size(0)
        img_shape = batch_inputs.shape[-2:]
        # Random pairing of every sample with another sample of the batch.
        index = torch.randperm(batch_size)

        (y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam)
        # Paste the patch of the paired sample into each sample (in place).
        batch_inputs[:, :, y1:y2, x1:x2] = batch_inputs[index, :, y1:y2, x1:x2]
        mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :]

        return batch_inputs, mixed_scores