Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import List, Optional, Tuple | |
import torch | |
import torch.nn.functional as F | |
from torch import Tensor | |
def _flip_grouped_maps(maps: Tensor, group_size: int, x_offset_index: int,
                       flip_indices: Optional[List[int]]) -> Tensor:
    """Flip maps that store ``group_size`` consecutive channels per keypoint.

    The maps are flipped along the last axis, the keypoint groups are
    optionally permuted with ``flip_indices``, and the channel at position
    ``x_offset_index`` inside every group is negated.
    """
    B, C, H, W = maps.shape
    if C % group_size != 0:
        # Without this check a ``-1`` in the reshape would silently absorb
        # the remainder and the wrong channels would be negated.
        raise ValueError(
            f'The channel number ({C}) must be divisible by {group_size}')
    num_keypoints = C // group_size

    # ``flip`` copies the data, so the in-place negation below never writes
    # into the tensor passed in by the caller.
    grouped = maps.reshape(B, num_keypoints, group_size, H, W).flip(-1)
    if flip_indices is not None:
        assert len(flip_indices) == num_keypoints
        grouped = grouped[:, flip_indices]
    grouped[:, :, x_offset_index] = -grouped[:, :, x_offset_index]
    return grouped.reshape(B, C, H, W)


def flip_heatmaps(heatmaps: Tensor,
                  flip_indices: Optional[List[int]] = None,
                  flip_mode: str = 'heatmap',
                  shift_heatmap: bool = True):
    """Flip heatmaps for test-time augmentation.

    Args:
        heatmaps (Tensor): The heatmaps to flip. Should be a tensor in shape
            [B, C, H, W]
        flip_indices (List[int]): The indices of each keypoint's symmetric
            keypoint. If ``None``, the channels are not permuted. Defaults
            to ``None``
        flip_mode (str): Specify the flipping mode. Options are:

            - ``'heatmap'``: horizontally flip the heatmaps and swap heatmaps
                of symmetric keypoints according to ``flip_indices``
            - ``'udp_combined'``: similar to ``'heatmap'`` mode, but the
                channels are treated as groups of 3 per keypoint and the
                x_offset channel (index 1 of each group) is negated
            - ``'offset'``: horizontally flip the offset fields and swap
                the offset fields of symmetric keypoints according to
                ``flip_indices``. The channels are treated as groups of 2
                per keypoint and the x_offset channel (index 0 of each
                group) is negated
        shift_heatmap (bool): Shift the flipped heatmaps to align with the
            original heatmaps and improve accuracy. Defaults to ``True``

    Returns:
        Tensor: flipped heatmaps in shape [B, C, H, W]

    Raises:
        ValueError: If ``flip_mode`` is not one of the options above, or if
            the channel number is not divisible by the group size required
            by ``'udp_combined'`` (3) or ``'offset'`` (2).
    """
    if flip_mode == 'heatmap':
        heatmaps = heatmaps.flip(-1)
        if flip_indices is not None:
            assert len(flip_indices) == heatmaps.shape[1]
            heatmaps = heatmaps[:, flip_indices]
    elif flip_mode == 'udp_combined':
        heatmaps = _flip_grouped_maps(
            heatmaps, group_size=3, x_offset_index=1,
            flip_indices=flip_indices)
    elif flip_mode == 'offset':
        heatmaps = _flip_grouped_maps(
            heatmaps, group_size=2, x_offset_index=0,
            flip_indices=flip_indices)
    else:
        raise ValueError(f'Invalid flip_mode value "{flip_mode}"')

    if shift_heatmap:
        # clone data to avoid unexpected in-place operation when using CPU
        # (source and destination slices overlap in memory)
        heatmaps[..., 1:] = heatmaps[..., :-1].clone()

    return heatmaps
def flip_vectors(x_labels: Tensor, y_labels: Tensor, flip_indices: List[int]):
    """Flip instance-level labels in specific axis for test-time augmentation.

    Args:
        x_labels (Tensor): The vector labels in x-axis to flip. Should be
            a tensor in shape [B, C, Wx]
        y_labels (Tensor): The vector labels in y-axis to flip. Should be
            a tensor in shape [B, C, Wy]
        flip_indices (List[int]): The indices of each keypoint's symmetric
            keypoint

    Returns:
        Tuple[Tensor, Tensor]: ``x_labels`` with symmetric keypoints swapped
        and the last axis reversed, and ``y_labels`` with symmetric keypoints
        swapped only.
    """
    assert x_labels.ndim == 3 and y_labels.ndim == 3
    num_indices = len(flip_indices)
    assert (num_indices == x_labels.shape[1]
            and num_indices == y_labels.shape[1])

    # Swap left/right keypoints on both axes first ...
    swapped_x = x_labels[:, flip_indices]
    swapped_y = y_labels[:, flip_indices]

    # ... then mirror only the x-axis labels; the y-axis is unaffected by a
    # horizontal flip.
    return torch.flip(swapped_x, dims=(-1, )), swapped_y
def flip_coordinates(coords: Tensor, flip_indices: List[int],
                     shift_coords: bool, input_size: Tuple[int, int]):
    """Flip normalized coordinates for test-time augmentation.

    The input tensor is left untouched; a new tensor is returned.

    Args:
        coords (Tensor): The coordinates to flip. Should be a tensor in shape
            [B, K, D], where the first entry of the last axis is the
            normalized x coordinate
        flip_indices (List[int]): The indices of each keypoint's symmetric
            keypoint
        shift_coords (bool): Shift the flipped coordinates to align with the
            original coordinates and improve accuracy
        input_size (Tuple[int, int]): The size of input image in [w, h]. Only
            the width is used, and only when ``shift_coords`` is ``True``

    Returns:
        Tensor: flipped coordinates in shape [B, K, D]
    """
    assert coords.ndim == 3
    assert len(flip_indices) == coords.shape[1]

    # Work on a copy: the in-place writes below would otherwise leave the
    # caller's tensor with mirrored x values.
    coords = coords.clone()
    coords[:, :, 0] = 1.0 - coords[:, :, 0]

    if shift_coords:
        img_width = input_size[0]
        # Move x by one pixel expressed in normalized units
        coords[:, :, 0] -= 1.0 / img_width

    # Swap each keypoint with its symmetric counterpart
    coords = coords[:, flip_indices]
    return coords
def aggregate_heatmaps(heatmaps: List[Tensor],
                       size: Optional[Tuple[int, int]] = None,
                       align_corners: bool = False,
                       mode: str = 'average'):
    """Aggregate multiple heatmaps.

    The input list and its tensors are not modified.

    Args:
        heatmaps (List[Tensor]): Multiple heatmaps to aggregate. Each should
            be in shape (B, C, H, W)
        size (Tuple[int, int], optional): The target size in (w, h). All
            heatmaps will be resized to the target size. If not given, the
            first heatmap tensor's width and height will be used as the target
            size. Defaults to ``None``
        align_corners (bool): Whether align corners when resizing heatmaps.
            Defaults to ``False``
        mode (str): Aggregation mode in one of the following:

            - ``'average'``: Get average of heatmaps. All heatmaps must have
                the same channel number
            - ``'concat'``: Concatenate the heatmaps at the channel dim

    Returns:
        Tensor: The aggregated heatmaps in shape (B, C', h, w), where C' is
        the common channel number in ``'average'`` mode and the sum of all
        channel numbers in ``'concat'`` mode.

    Raises:
        ValueError: If ``mode`` is invalid or ``heatmaps`` is empty.
    """
    if mode not in {'average', 'concat'}:
        raise ValueError(f'Invalid aggregation mode `{mode}`')

    if not heatmaps:
        raise ValueError('`heatmaps` must contain at least one tensor')

    if size is None:
        h, w = heatmaps[0].shape[2:4]
    else:
        # ``size`` is given as (w, h) while tensors are laid out as (h, w)
        w, h = size

    # Collect results in a new list so the caller's list is left intact
    resized = []
    for _heatmaps in heatmaps:
        assert _heatmaps.ndim == 4
        if mode == 'average':
            assert _heatmaps.shape[:2] == heatmaps[0].shape[:2]
        else:
            assert _heatmaps.shape[0] == heatmaps[0].shape[0]

        if _heatmaps.shape[2:4] != (h, w):
            _heatmaps = F.interpolate(
                _heatmaps,
                size=(h, w),
                mode='bilinear',
                align_corners=align_corners)
        resized.append(_heatmaps)

    if mode == 'average':
        return sum(resized).div(len(resized))

    # mode == 'concat' (validated at the top of the function)
    return torch.cat(resized, dim=1)