File size: 6,190 Bytes
cc0dd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def flip_heatmaps(heatmaps: Tensor,
                  flip_indices: Optional[List[int]] = None,
                  flip_mode: str = 'heatmap',
                  shift_heatmap: bool = True):
    """Mirror dense model outputs left-right for flip-based test-time
    augmentation.

    The input tensor itself is not modified: the horizontal flip produces a
    new tensor and every subsequent in-place edit is applied to that copy.

    Args:
        heatmaps (Tensor): The maps to mirror, in shape [B, C, H, W]
        flip_indices (List[int], optional): For each keypoint, the index of
            its left-right symmetric counterpart. When ``None``, no keypoint
            swapping is performed. Defaults to ``None``
        flip_mode (str): How the channels are interpreted. One of:

            - ``'heatmap'``: every channel is one keypoint heatmap. The maps
                are mirrored along the last axis and the channels are
                reordered by ``flip_indices``
            - ``'udp_combined'``: channels are grouped in triples per
                keypoint. Besides mirroring and swapping keypoints, the
                second channel of each triple (the x_offset) is negated
            - ``'offset'``: channels are grouped per keypoint as offset
                fields. Besides mirroring and swapping keypoints, the first
                channel of each group (the x_offset) is negated
        shift_heatmap (bool): If ``True``, move the mirrored maps one pixel
            to the right along the last axis so that they line up with the
            un-flipped maps; the first column keeps its value. Defaults to
            ``True``

    Returns:
        Tensor: The mirrored maps in shape [B, C, H, W]

    Raises:
        ValueError: If ``flip_mode`` is not one of the supported options.
    """

    if flip_mode == 'heatmap':
        mirrored = heatmaps.flip(-1)
        if flip_indices is not None:
            assert len(flip_indices) == mirrored.shape[1]
            mirrored = mirrored[:, flip_indices]

    elif flip_mode in ('udp_combined', 'offset'):
        batch, channels, height, width = heatmaps.shape

        # Channel layout per keypoint: (group count, group length passed to
        # ``view``, index of the x_offset channel inside a group).
        if flip_mode == 'udp_combined':
            num_groups, group_len, x_channel = channels // 3, 3, 1
        else:
            num_groups, group_len, x_channel = channels // 2, -1, 0

        grouped = heatmaps.view(batch, num_groups, group_len, height, width)
        grouped = grouped.flip(-1)
        if flip_indices is not None:
            assert len(flip_indices) == num_groups
            grouped = grouped[:, flip_indices]

        # A horizontal mirror reverses the sign of horizontal offsets.
        grouped[:, :, x_channel] = grouped[:, :, x_channel].neg()
        mirrored = grouped.view(batch, channels, height, width)

    else:
        raise ValueError(f'Invalid flip_mode value "{flip_mode}"')

    if not shift_heatmap:
        return mirrored

    # Snapshot the source columns first: source and destination overlap, and
    # an overlapping in-place copy is not reliable (notably on CPU).
    source_columns = mirrored[..., :-1].clone()
    mirrored[..., 1:] = source_columns
    return mirrored


def flip_vectors(x_labels: Tensor, y_labels: Tensor, flip_indices: List[int]):
    """Mirror per-axis vector labels left-right for test-time augmentation.

    Keypoints are swapped with their symmetric counterparts in both label
    sets. Only the x-axis labels are additionally reversed along their last
    dimension; the y-axis labels keep their bin order.

    Args:
        x_labels (Tensor): The x-axis vector labels, in shape [B, C, Wx]
        y_labels (Tensor): The y-axis vector labels, in shape [B, C, Wy]
        flip_indices (List[int]): For each keypoint, the index of its
            left-right symmetric counterpart

    Returns:
        Tuple[Tensor, Tensor]: The flipped ``(x_labels, y_labels)``, with the
        same shapes as the inputs
    """
    assert x_labels.ndim == 3
    assert y_labels.ndim == 3

    num_keypoints = len(flip_indices)
    assert x_labels.shape[1] == num_keypoints
    assert y_labels.shape[1] == num_keypoints

    swapped_x = x_labels[:, flip_indices]
    mirrored_x = swapped_x.flip(-1)
    swapped_y = y_labels[:, flip_indices]

    return mirrored_x, swapped_y


def flip_coordinates(coords: Tensor, flip_indices: List[int],
                     shift_coords: bool, input_size: Tuple[int, int]):
    """Flip normalized coordinates for test-time augmentation.

    The x component (index 0 of the last dimension) is mirrored as
    ``1.0 - x`` and keypoints are swapped with their symmetric counterparts.
    The input tensor is left untouched; a new tensor is returned.

    Args:
        coords (Tensor): The coordinates to flip. Should be a tensor in shape
            [B, K, D]
        flip_indices (List[int]): The indices of each keypoint's symmetric
            keypoint
        shift_coords (bool): Shift the flipped coordinates to align with the
            original coordinates and improve accuracy. When ``True``, the
            flipped x values are reduced by one pixel expressed in
            normalized units, i.e. ``1.0 / input_size[0]``
        input_size (Tuple[int, int]): The size of input image in [w, h].
            Only the width is used, and only when ``shift_coords`` is
            ``True``

    Returns:
        Tensor: The flipped coordinates in shape [B, K, D]
    """
    assert coords.ndim == 3
    assert len(flip_indices) == coords.shape[1]

    # Reorder the keypoints first: indexing with a list yields a new tensor,
    # so the in-place edits below never write into the caller's ``coords``.
    # The x update is element-wise, hence it commutes with the reordering.
    flipped = coords[:, flip_indices]

    flipped[:, :, 0] = 1.0 - flipped[:, :, 0]

    if shift_coords:
        img_width = input_size[0]
        flipped[:, :, 0] -= 1.0 / img_width

    return flipped


def aggregate_heatmaps(heatmaps: List[Tensor],
                       size: Optional[Tuple[int, int]] = None,
                       align_corners: bool = False,
                       mode: str = 'average'):
    """Aggregate multiple heatmaps.

    The input sequence is not modified: resized heatmaps are collected in a
    new list instead of being written back into ``heatmaps``.

    Args:
        heatmaps (List[Tensor]): Multiple heatmaps to aggregate. Each should
            be in shape (B, C, H, W)
        size (Tuple[int, int], optional): The target size in (w, h). All
            heatmaps will be resized to the target size. If not given, the
            first heatmap tensor's width and height will be used as the target
            size. Defaults to ``None``
        align_corners (bool): Whether align corners when resizing heatmaps.
            Defaults to ``False``
        mode (str): Aggregation mode in one of the following:

            - ``'average'``: Get average of heatmaps. All heatmaps must have
                the same batch size and channel number
            - ``'concat'``: Concatenate the heatmaps at the channel dim. All
                heatmaps must have the same batch size

    Returns:
        Tensor: The aggregated heatmaps. In ``'average'`` mode the shape is
        (B, C, h, w); in ``'concat'`` mode the channel dim is the sum of the
        inputs' channel numbers

    Raises:
        ValueError: If ``mode`` is not supported or ``heatmaps`` is empty.
    """

    if mode not in {'average', 'concat'}:
        raise ValueError(f'Invalid aggregation mode `{mode}`')

    if len(heatmaps) == 0:
        raise ValueError('`heatmaps` should contain at least one tensor')

    reference = heatmaps[0]
    if size is None:
        h, w = reference.shape[2:4]
    else:
        w, h = size

    # Averaging needs matching batch and channel dims; concatenating along
    # the channel dim only needs matching batch dims.
    num_matched_dims = 2 if mode == 'average' else 1

    resized = []
    for _heatmaps in heatmaps:
        assert _heatmaps.ndim == 4
        assert (_heatmaps.shape[:num_matched_dims] ==
                reference.shape[:num_matched_dims])

        if _heatmaps.shape[2:4] != (h, w):
            _heatmaps = F.interpolate(
                _heatmaps,
                size=(h, w),
                mode='bilinear',
                align_corners=align_corners)
        resized.append(_heatmaps)

    if mode == 'average':
        return sum(resized).div(len(resized))

    # mode == 'concat' (validated above)
    return torch.cat(resized, dim=1)